From 171f5bc832f7431200797766a2bb07f7179b4480 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Wed, 28 Dec 2022 10:00:31 -0800 Subject: [PATCH] Add a minimize_with_timeout wrapper for scipy.optimize.minimize (#1403) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1403 Unfortunately, scipy does not allow timing out the optimization based on wall time. This diff implements a lightweight wrapper around `scipy.optimize.minimize` to achieve this. The new `minimize_with_timeout` method calls `scipy.optimize.minimize` with all arguments forwarded verbatim. The only difference is that if provided a `timeout_sec` kwarg, it automatically stops the optimization after the timeout is reached. Internally, this is achieved by automatically constructing a callback method that is injected to the `scipy.optimize.minimize` call that keeps track of the runtime and is used to extract the value of the optimization variables at the current iteration. Differential Revision: D39529835 fbshipit-source-id: 50f773bfa2fb5b50ab9c6d724d6fc4a67e4f6a6e --- botorch/exceptions/__init__.py | 2 + botorch/exceptions/errors.py | 23 +++++++ botorch/optim/utils/__init__.py | 2 + botorch/optim/utils/timeout.py | 108 +++++++++++++++++++++++++++++++ sphinx/source/optim.rst | 5 ++ test/exceptions/test_errors.py | 12 ++++ test/optim/utils/__init__.py | 5 ++ test/optim/utils/test_timeout.py | 99 ++++++++++++++++++++++++++++ 8 files changed, 256 insertions(+) create mode 100644 botorch/optim/utils/timeout.py create mode 100644 test/optim/utils/test_timeout.py diff --git a/botorch/exceptions/__init__.py b/botorch/exceptions/__init__.py index 4e5f0a8049..6de269f703 100644 --- a/botorch/exceptions/__init__.py +++ b/botorch/exceptions/__init__.py @@ -10,6 +10,7 @@ CandidateGenerationError, InputDataError, ModelFittingError, + OptimizationTimeoutError, UnsupportedError, ) from botorch.exceptions.warnings import ( @@ -35,6 +36,7 @@ "BadInitialCandidatesWarning", "CandidateGenerationError", "ModelFittingError", + "OptimizationTimeoutError", "OptimizationWarning", "SamplingWarning", "UnsupportedError", diff --git a/botorch/exceptions/errors.py b/botorch/exceptions/errors.py index a7c8872c85..9b26692932 100644 --- a/botorch/exceptions/errors.py +++ b/botorch/exceptions/errors.py @@ -8,6 +8,10 @@ Botorch Errors. """ +from typing import Any + +import numpy as np + class BotorchError(Exception): r"""Base botorch exception.""" @@ -43,3 +47,22 @@ class ModelFittingError(Exception): r"""Exception raised when attempts to fit a model terminate unsuccessfully.""" pass + + +class OptimizationTimeoutError(BotorchError): + r"""Exception raised when optimization times out.""" + + def __init__( + self, /, *args: Any, current_x: np.ndarray, runtime: float, **kwargs: Any + ) -> None: + r""" + Args: + *args: Standard args to `BoTorchError`. + current_x: A numpy array representing the current iterate. + runtime: The total runtime in seconds after which the optimization + timed out. + **kwargs: Standard kwargs to `BoTorchError`. + """ + super().__init__(*args, **kwargs) + self.current_x = current_x + self.runtime = runtime diff --git a/botorch/optim/utils/__init__.py b/botorch/optim/utils/__init__.py index 24fe6b5c27..552363898a 100644 --- a/botorch/optim/utils/__init__.py +++ b/botorch/optim/utils/__init__.py @@ -31,6 +31,7 @@ get_tensors_as_ndarray_1d, set_tensors_from_ndarray_1d, ) +from botorch.optim.utils.timeout import minimize_with_timeout __all__ = [ "_filter_kwargs", @@ -48,6 +49,7 @@ "get_parameters_and_bounds", "get_tensors_as_ndarray_1d", "get_X_baseline", + "minimize_with_timeout", "sample_all_priors", "set_tensors_from_ndarray_1d", "TorchAttr", diff --git a/botorch/optim/utils/timeout.py b/botorch/optim/utils/timeout.py new file mode 100644 index 0000000000..d0f7bd7f07 --- /dev/null +++ b/botorch/optim/utils/timeout.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import time +import warnings +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +import numpy as np +from botorch.exceptions.errors import OptimizationTimeoutError +from botorch.exceptions.warnings import OptimizationWarning +from scipy import optimize + + +def minimize_with_timeout( + fun: Callable[[np.ndarray, *Any], float], + x0: np.ndarray, + args: Tuple[Any, ...] = (), + method: Optional[str] = None, + jac: Optional[Union[str, Callable, bool]] = None, + hess: Optional[Union[str, Callable, optimize.HessianUpdateStrategy]] = None, + hessp: Optional[Callable] = None, + bounds: Optional[Union[Sequence[Tuple[float, float]], optimize.Bounds]] = None, + constraints=(), # Typing this properly is a s**t job + tol: Optional[float] = None, + callback: Optional[Callable] = None, + options: Optional[Dict[str, Any]] = None, + timeout_sec: Optional[float] = None, +) -> optimize.OptimizeResult: + r"""Wrapper around scipy.optimize.minimize to support timeout. + + This method calls scipy.optimize.minimize with all arguments forwarded + verbatim. The only difference is that if provided a `timeout_sec` argument, + it will automatically stop the optimziation after the timeout is reached. + + Internally, this is achieved by automatically constructing a wrapper callback + method that is injected to the scipy.optimize.minimize call and that keeps + track of the runtime and the optimization variables at the current iteration. + """ + if timeout_sec: + + start_time = time.monotonic() + callback_data = {"num_iterations": 0} # update from withing callback below + + def timeout_callback(xk: np.ndarray) -> bool: + runtime = time.monotonic() - start_time + callback_data["num_iterations"] += 1 + if runtime > timeout_sec: + raise OptimizationTimeoutError(current_x=xk, runtime=runtime) + return False + + if callback is None: + wrapped_callback = timeout_callback + + elif callable(method): + raise NotImplementedError( + "Custom callable not supported for `method` argument." + ) + + elif method == "trust-constr": # special signature + + def wrapped_callback( + xk: np.ndarray, state: optimize.OptimizeResult + ) -> bool: + # order here is important to make sure base callback gets executed + return callback(xk, state) or timeout_callback(xk=xk) + + else: + + def wrapped_callback(xk: np.ndarray) -> None: + timeout_callback(xk=xk) + callback(xk) + + else: + wrapped_callback = callback + + try: + return optimize.minimize( + fun=fun, + x0=x0, + args=args, + method=method, + jac=jac, + hess=hess, + hessp=hessp, + bounds=bounds, + constraints=constraints, + tol=tol, + callback=wrapped_callback, + options=options, + ) + except OptimizationTimeoutError as e: + msg = f"Optimization timed out after {e.runtime} seconds." + warnings.warn(msg, OptimizationWarning) + current_fun, *_ = fun(e.current_x, *args) + + return optimize.OptimizeResult( + fun=current_fun, + x=e.current_x, + nit=callback_data["num_iterations"], + success=False, # same as when maxiter is reached + status=1, # same as when L-BFGS-B reaches maxiter + message=msg, + ) diff --git a/sphinx/source/optim.rst b/sphinx/source/optim.rst index bd8ac5bab6..e6b9b3e11a 100644 --- a/sphinx/source/optim.rst +++ b/sphinx/source/optim.rst @@ -73,6 +73,11 @@ Numpy - Torch Conversion Tools .. automodule:: botorch.optim.utils.numpy_utils :members: +Optimization with Timeouts +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.utils.timeout + :members: + Numpy - Torch Conversion Tools (OLD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.optim.numpy_converter diff --git a/test/exceptions/test_errors.py b/test/exceptions/test_errors.py index 54ba0f3447..12ed03bee1 100644 --- a/test/exceptions/test_errors.py +++ b/test/exceptions/test_errors.py @@ -4,11 +4,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import numpy as np + from botorch.exceptions.errors import ( BotorchError, BotorchTensorDimensionError, CandidateGenerationError, InputDataError, + OptimizationTimeoutError, UnsupportedError, ) from botorch.utils.testing import BotorchTestCase @@ -32,3 +35,12 @@ def test_raise_botorch_exceptions(self): ): with self.assertRaises(ErrorClass): raise ErrorClass("message") + + def test_OptimizationTimeoutError(self): + error = OptimizationTimeoutError( + "message", current_x=np.array([1.0]), runtime=0.123 + ) + self.assertEqual(error.runtime, 0.123) + self.assertTrue(np.array_equal(error.current_x, np.array([1.0]))) + with self.assertRaises(OptimizationTimeoutError): + raise error diff --git a/test/optim/utils/__init__.py b/test/optim/utils/__init__.py index e69de29bb2..4b87eb9e4d 100644 --- a/test/optim/utils/__init__.py +++ b/test/optim/utils/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/optim/utils/test_timeout.py b/test/optim/utils/test_timeout.py new file mode 100644 index 0000000000..32178117e6 --- /dev/null +++ b/test/optim/utils/test_timeout.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import numpy as np +from botorch.optim.utils.timeout import minimize_with_timeout +from botorch.utils.testing import BotorchTestCase +from scipy.optimize import OptimizeResult + + +class TestMinimizeWithTimeout(BotorchTestCase): + def test_minimize_with_timeout(self): + def f_and_g(x: np.ndarray, sleep_sec: float = 0.0): + time.sleep(sleep_sec) + return x**2, 2 * x + + base_kwargs = { + "fun": f_and_g, + "x0": np.array([1.0]), + "method": "L-BFGS-B", + "jac": True, + "bounds": [(-2.0, 2.0)], + } + + with self.subTest("test w/o timeout"): + res = minimize_with_timeout(**base_kwargs) + self.assertTrue(res.success) + self.assertAlmostEqual(res.fun, 0.0) + self.assertAlmostEqual(res.x, 0.0) + self.assertEqual(res.nit, 2) # quadratic approx. is exact + + with self.subTest("test w/ non-binding timeout"): + res = minimize_with_timeout(**base_kwargs, timeout_sec=1.0) + self.assertTrue(res.success) + self.assertAlmostEqual(res.fun, 0.0) + self.assertAlmostEqual(res.x, 0.0) + self.assertEqual(res.nit, 2) # quadratic approx. is exact + + with self.subTest("test w/ binding timeout"): + res = minimize_with_timeout(**base_kwargs, args=(1e-3,), timeout_sec=1e-4) + self.assertFalse(res.success) + self.assertEqual(res.nit, 1) # only one call to the callback is made + + # set up callback with mutable object to verify callback execution + check_set = set() + + def callback(x: np.ndarray) -> None: + check_set.add("foo") + + with self.subTest("test w/ callout argument and non-binding timeout"): + res = minimize_with_timeout( + **base_kwargs, callback=callback, timeout_sec=1.0 + ) + self.assertTrue(res.success) + self.assertTrue("foo" in check_set) + + # set up callback for method `trust-constr` w/ different signature + check_set.clear() + self.assertFalse("foo" in check_set) + + def callback_trustconstr(x: np.ndarray, state: OptimizeResult) -> bool: + check_set.add("foo") + return False + + with self.subTest("test `trust-constr` method w/ callback"): + res = minimize_with_timeout( + **{**base_kwargs, "method": "trust-constr"}, + callback=callback_trustconstr, + ) + self.assertTrue(res.success) + self.assertTrue("foo" in check_set) + + # reset check set + check_set.clear() + self.assertFalse("foo" in check_set) + + with self.subTest("test `trust-constr` method w/ callback and timeout"): + res = minimize_with_timeout( + **{**base_kwargs, "method": "trust-constr"}, + args=(1e-3,), + callback=callback_trustconstr, + timeout_sec=1e-4, + ) + self.assertFalse(res.success) + self.assertTrue("foo" in check_set) + + with self.subTest("verify error if passing callable for `method` w/ timeout"): + with self.assertRaisesRegex( + NotImplementedError, "Custom callable not supported" + ): + minimize_with_timeout( + **{**base_kwargs, "method": lambda *args, **kwargs: None}, + callback=callback, + timeout_sec=1e-4, + )