Skip to content

Commit

Permalink
Add a minimize_with_timeout wrapper for scipy.optimize.minimize (#1403)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1403

Unfortunately, scipy does not allow timing out the optimization based on wall time. This diff implements a lightweight wrapper around `scipy.optimize.minimize` to achieve this.

The new `minimize_with_timeout` method calls `scipy.optimize.minimize` with all arguments forwarded verbatim. The only difference is that if provided a `timeout_sec` kwarg, it automatically stops the optimization after the timeout is reached.

Internally, this is achieved by automatically constructing a callback method that is injected to the `scipy.optimize.minimize` call that keeps track of the runtime and is used to extract the value of the optimization variables at the current iteration.

Differential Revision: D39529835

fbshipit-source-id: 50f773bfa2fb5b50ab9c6d724d6fc4a67e4f6a6e
  • Loading branch information
Balandat authored and facebook-github-bot committed Dec 28, 2022
1 parent 02ec731 commit 171f5bc
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 0 deletions.
2 changes: 2 additions & 0 deletions botorch/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CandidateGenerationError,
InputDataError,
ModelFittingError,
OptimizationTimeoutError,
UnsupportedError,
)
from botorch.exceptions.warnings import (
Expand All @@ -35,6 +36,7 @@
"BadInitialCandidatesWarning",
"CandidateGenerationError",
"ModelFittingError",
"OptimizationTimeoutError",
"OptimizationWarning",
"SamplingWarning",
"UnsupportedError",
Expand Down
23 changes: 23 additions & 0 deletions botorch/exceptions/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
Botorch Errors.
"""

from typing import Any

import numpy as np


class BotorchError(Exception):
r"""Base botorch exception."""
Expand Down Expand Up @@ -43,3 +47,22 @@ class ModelFittingError(Exception):
r"""Exception raised when attempts to fit a model terminate unsuccessfully."""

pass


class OptimizationTimeoutError(BotorchError):
r"""Exception raised when optimization times out."""

def __init__(
self, /, *args: Any, current_x: np.ndarray, runtime: float, **kwargs: Any
) -> None:
r"""
Args:
*args: Standard args to `BoTorchError`.
current_x: A numpy array representing the current iterate.
runtime: The total runtime in seconds after which the optimization
timed out.
**kwargs: Standard kwargs to `BoTorchError`.
"""
super().__init__(*args, **kwargs)
self.current_x = current_x
self.runtime = runtime
2 changes: 2 additions & 0 deletions botorch/optim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_tensors_as_ndarray_1d,
set_tensors_from_ndarray_1d,
)
from botorch.optim.utils.timeout import minimize_with_timeout

__all__ = [
"_filter_kwargs",
Expand All @@ -48,6 +49,7 @@
"get_parameters_and_bounds",
"get_tensors_as_ndarray_1d",
"get_X_baseline",
"minimize_with_timeout",
"sample_all_priors",
"set_tensors_from_ndarray_1d",
"TorchAttr",
Expand Down
108 changes: 108 additions & 0 deletions botorch/optim/utils/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import time
import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from botorch.exceptions.errors import OptimizationTimeoutError
from botorch.exceptions.warnings import OptimizationWarning
from scipy import optimize


def minimize_with_timeout(
fun: Callable[[np.ndarray, *Any], float],
x0: np.ndarray,
args: Tuple[Any, ...] = (),
method: Optional[str] = None,
jac: Optional[Union[str, Callable, bool]] = None,
hess: Optional[Union[str, Callable, optimize.HessianUpdateStrategy]] = None,
hessp: Optional[Callable] = None,
bounds: Optional[Union[Sequence[Tuple[float, float]], optimize.Bounds]] = None,
constraints=(), # Typing this properly is a s**t job
tol: Optional[float] = None,
callback: Optional[Callable] = None,
options: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[float] = None,
) -> optimize.OptimizeResult:
r"""Wrapper around scipy.optimize.minimize to support timeout.
This method calls scipy.optimize.minimize with all arguments forwarded
verbatim. The only difference is that if provided a `timeout_sec` argument,
it will automatically stop the optimziation after the timeout is reached.
Internally, this is achieved by automatically constructing a wrapper callback
method that is injected to the scipy.optimize.minimize call and that keeps
track of the runtime and the optimization variables at the current iteration.
"""
if timeout_sec:

start_time = time.monotonic()
callback_data = {"num_iterations": 0} # update from withing callback below

def timeout_callback(xk: np.ndarray) -> bool:
runtime = time.monotonic() - start_time
callback_data["num_iterations"] += 1
if runtime > timeout_sec:
raise OptimizationTimeoutError(current_x=xk, runtime=runtime)
return False

if callback is None:
wrapped_callback = timeout_callback

elif callable(method):
raise NotImplementedError(
"Custom callable not supported for `method` argument."
)

elif method == "trust-constr": # special signature

def wrapped_callback(
xk: np.ndarray, state: optimize.OptimizeResult
) -> bool:
# order here is important to make sure base callback gets executed
return callback(xk, state) or timeout_callback(xk=xk)

else:

def wrapped_callback(xk: np.ndarray) -> None:
timeout_callback(xk=xk)
callback(xk)

else:
wrapped_callback = callback

try:
return optimize.minimize(
fun=fun,
x0=x0,
args=args,
method=method,
jac=jac,
hess=hess,
hessp=hessp,
bounds=bounds,
constraints=constraints,
tol=tol,
callback=wrapped_callback,
options=options,
)
except OptimizationTimeoutError as e:
msg = f"Optimization timed out after {e.runtime} seconds."
warnings.warn(msg, OptimizationWarning)
current_fun, *_ = fun(e.current_x, *args)

return optimize.OptimizeResult(
fun=current_fun,
x=e.current_x,
nit=callback_data["num_iterations"],
success=False, # same as when maxiter is reached
status=1, # same as when L-BFGS-B reaches maxiter
message=msg,
)
5 changes: 5 additions & 0 deletions sphinx/source/optim.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ Numpy - Torch Conversion Tools
.. automodule:: botorch.optim.utils.numpy_utils
:members:

Optimization with Timeouts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.optim.utils.timeout
:members:

Numpy - Torch Conversion Tools (OLD)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.optim.numpy_converter
Expand Down
12 changes: 12 additions & 0 deletions test/exceptions/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from botorch.exceptions.errors import (
BotorchError,
BotorchTensorDimensionError,
CandidateGenerationError,
InputDataError,
OptimizationTimeoutError,
UnsupportedError,
)
from botorch.utils.testing import BotorchTestCase
Expand All @@ -32,3 +35,12 @@ def test_raise_botorch_exceptions(self):
):
with self.assertRaises(ErrorClass):
raise ErrorClass("message")

def test_OptimizationTimeoutError(self):
error = OptimizationTimeoutError(
"message", current_x=np.array([1.0]), runtime=0.123
)
self.assertEqual(error.runtime, 0.123)
self.assertTrue(np.array_equal(error.current_x, np.array([1.0])))
with self.assertRaises(OptimizationTimeoutError):
raise error
5 changes: 5 additions & 0 deletions test/optim/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
99 changes: 99 additions & 0 deletions test/optim/utils/test_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import numpy as np
from botorch.optim.utils.timeout import minimize_with_timeout
from botorch.utils.testing import BotorchTestCase
from scipy.optimize import OptimizeResult


class TestMinimizeWithTimeout(BotorchTestCase):
def test_minimize_with_timeout(self):
def f_and_g(x: np.ndarray, sleep_sec: float = 0.0):
time.sleep(sleep_sec)
return x**2, 2 * x

base_kwargs = {
"fun": f_and_g,
"x0": np.array([1.0]),
"method": "L-BFGS-B",
"jac": True,
"bounds": [(-2.0, 2.0)],
}

with self.subTest("test w/o timeout"):
res = minimize_with_timeout(**base_kwargs)
self.assertTrue(res.success)
self.assertAlmostEqual(res.fun, 0.0)
self.assertAlmostEqual(res.x, 0.0)
self.assertEqual(res.nit, 2) # quadratic approx. is exact

with self.subTest("test w/ non-binding timeout"):
res = minimize_with_timeout(**base_kwargs, timeout_sec=1.0)
self.assertTrue(res.success)
self.assertAlmostEqual(res.fun, 0.0)
self.assertAlmostEqual(res.x, 0.0)
self.assertEqual(res.nit, 2) # quadratic approx. is exact

with self.subTest("test w/ binding timeout"):
res = minimize_with_timeout(**base_kwargs, args=(1e-3,), timeout_sec=1e-4)
self.assertFalse(res.success)
self.assertEqual(res.nit, 1) # only one call to the callback is made

# set up callback with mutable object to verify callback execution
check_set = set()

def callback(x: np.ndarray) -> None:
check_set.add("foo")

with self.subTest("test w/ callout argument and non-binding timeout"):
res = minimize_with_timeout(
**base_kwargs, callback=callback, timeout_sec=1.0
)
self.assertTrue(res.success)
self.assertTrue("foo" in check_set)

# set up callback for method `trust-constr` w/ different signature
check_set.clear()
self.assertFalse("foo" in check_set)

def callback_trustconstr(x: np.ndarray, state: OptimizeResult) -> bool:
check_set.add("foo")
return False

with self.subTest("test `trust-constr` method w/ callback"):
res = minimize_with_timeout(
**{**base_kwargs, "method": "trust-constr"},
callback=callback_trustconstr,
)
self.assertTrue(res.success)
self.assertTrue("foo" in check_set)

# reset check set
check_set.clear()
self.assertFalse("foo" in check_set)

with self.subTest("test `trust-constr` method w/ callback and timeout"):
res = minimize_with_timeout(
**{**base_kwargs, "method": "trust-constr"},
args=(1e-3,),
callback=callback_trustconstr,
timeout_sec=1e-4,
)
self.assertFalse(res.success)
self.assertTrue("foo" in check_set)

with self.subTest("verify error if passing callable for `method` w/ timeout"):
with self.assertRaisesRegex(
NotImplementedError, "Custom callable not supported"
):
minimize_with_timeout(
**{**base_kwargs, "method": lambda *args, **kwargs: None},
callback=callback,
timeout_sec=1e-4,
)

0 comments on commit 171f5bc

Please sign in to comment.