diff --git a/botorch/fit.py b/botorch/fit.py
index 1e67c89c86..36db2b3a53 100644
--- a/botorch/fit.py
+++ b/botorch/fit.py
@@ -48,10 +48,26 @@ from torch.utils.data import DataLoader
 
 
+def _debug_warn(w: WarningMessage) -> bool:
+    if _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)):
+        return True
+    # TODO: Better handle cases where warning handling logic
+    # affects both debug and rethrow functions.
+    return False
+
+
+def _rethrow_warn(w: WarningMessage) -> bool:
+    if not issubclass(w.category, OptimizationWarning):
+        return True
+    if "Optimization timed out after" in str(w.message):
+        return True
+    return False
+
+
 DEFAULT_WARNING_HANDLER = partial(
     _warning_handler_template,
-    debug=lambda w: _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)),
-    rethrow=lambda w: not issubclass(w.category, OptimizationWarning),
+    debug=_debug_warn,
+    rethrow=_rethrow_warn,
 )
 FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder)
 
 
@@ -188,9 +204,9 @@ def _fit_fallback(
         optimizer_kwargs: Keyword arguments passed to `optimizer`.
         max_attempts: The maximum number of fit attempts allowed. The attempt budget
             is NOT shared between calls to this method.
-        warning_filter: A function used to filter warnings produced when calling
-            `optimizer`. Any unfiltered warnings will be rethrown and trigger a
-            model fitting retry.
+        warning_handler: A function used to filter warnings produced when calling
+            `optimizer`. Any unfiltered warnings (those for which `warning_handler`
+            returns `False`) will be rethrown and trigger a model fitting retry.
         caught_exception_types: A tuple of exception types whose instances should
             be redirected to `logging.DEBUG`.
         **ignore: This function ignores unrecognized keyword arguments.
diff --git a/botorch/optim/core.py b/botorch/optim/core.py
index 49cb92831d..97afdb59f1 100644
--- a/botorch/optim/core.py
+++ b/botorch/optim/core.py
@@ -17,9 +17,9 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 from botorch.optim.closures import NdarrayOptimizationClosure
-from botorch.optim.utils import get_bounds_as_ndarray
+from botorch.optim.utils.numpy_utils import get_bounds_as_ndarray
+from botorch.optim.utils.timeout import minimize_with_timeout
 from numpy import asarray, float64 as np_float64, ndarray
-from scipy.optimize import minimize
 from torch import Tensor
 from torch.optim.adam import Adam
 from torch.optim.optimizer import Optimizer
@@ -62,6 +62,7 @@ def scipy_minimize(
     x0: Optional[ndarray] = None,
     method: str = "L-BFGS-B",
     options: Optional[Dict[str, Any]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic scipy.optimize.minimize-based optimization routine.
 
@@ -74,6 +75,8 @@ def scipy_minimize(
         x0: An optional initialization vector passed to scipy.optimize.minimize.
         method: Solver type, passed along to scipy.minimize.
         options: Dictionary of solver options, passed along to scipy.minimize.
+        timeout_sec: Timeout in seconds to wait before aborting the optimization loop
+            if not converged (will return the best found solution thus far).
 
     Returns:
         An OptimizationResult summarizing the final state of the run.
@@ -103,7 +106,7 @@ def wrapped_callback(x: ndarray):
         )
         return callback(parameters, result)  # pyre-ignore [29]
 
-    raw = minimize(
+    raw = minimize_with_timeout(
         wrapped_closure,
         wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
         jac=True,
@@ -111,6 +114,7 @@ def wrapped_callback(x: ndarray):
         method=method,
         options=options,
         callback=wrapped_callback,
+        timeout_sec=timeout_sec,
     )
 
     # Post-processing and outcome handling
@@ -122,6 +126,7 @@ def wrapped_callback(x: ndarray):
         status = (  # Check whether we stopped due to reaching maxfun or maxiter
             OptimizationStatus.STOPPED
             if _LBFGSB_MAXITER_MAXFUN_REGEX.search(msg)
+            or "Optimization timed out after" in msg
             else OptimizationStatus.FAILURE
         )
 
@@ -142,6 +147,7 @@ def torch_minimize(
     optimizer: Union[Optimizer, Callable[[List[Tensor]], Optimizer]] = Adam,
     scheduler: Optional[Union[LRScheduler, Callable[[Optimizer], LRScheduler]]] = None,
     step_limit: Optional[int] = None,
+    timeout_sec: Optional[float] = None,
     stopping_criterion: Optional[Callable[[Tensor], bool]] = None,
 ) -> OptimizationResult:
     r"""Generic torch.optim-based optimization routine.
@@ -152,20 +158,24 @@ def torch_minimize(
         parameters: A dictionary of tensors to be optimized.
         bounds: An optional dictionary of bounds for elements of `parameters`.
         callback: A callable taking `parameters` and an OptimizationResult as
            arguments.
-        step_limit: Integer specifying a maximum number of optimization steps.
-            One of `step_limit` or `stopping_criterion` must be passed.
-        stopping_criterion: A StoppingCriterion for the optimization loop.
         optimizer: A `torch.optim.Optimizer` instance or a factory that takes
             a list of parameters and returns an `Optimizer` instance.
         scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory
             that takes a `Optimizer` instance and returns a `_LRSchedule` instance.
+        step_limit: Integer specifying a maximum number of optimization steps.
+            One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
+        timeout_sec: Timeout in seconds before terminating the optimization loop.
+            One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
+        stopping_criterion: A StoppingCriterion for the optimization loop.
 
     Returns:
         An OptimizationResult summarizing the final state of the run.
     """
+    result: OptimizationResult
     start_time = monotonic()
+
     if step_limit is None:
-        if stopping_criterion is None:
+        if stopping_criterion is None and timeout_sec is None:
             raise RuntimeError("No termination conditions were given.")
         step_limit = maxsize
@@ -180,14 +190,14 @@ def torch_minimize(
         if bounds is None
         else {name: limits for name, limits in bounds.items() if name in parameters}
     )
-    result: OptimizationResult
-    for step in range(step_limit):
+    for step in range(1, step_limit + 1):
         fval, _ = closure()
+        runtime = monotonic() - start_time
         result = OptimizationResult(
             step=step,
             fval=fval.detach().cpu().item(),
             status=OptimizationStatus.RUNNING,
-            runtime=monotonic() - start_time,
+            runtime=runtime,
         )
 
         # TODO: Update stopping_criterion API to return a message.
@@ -195,6 +205,12 @@ def torch_minimize(
             result.status = OptimizationStatus.STOPPED
             result.message = "`torch_minimize` stopped due to `stopping_criterion`."
 
+        if timeout_sec is not None and runtime >= timeout_sec:
+            result.status = OptimizationStatus.STOPPED
+            result.message = (
+                f"`torch_minimize` stopped due to timeout after {runtime} seconds."
+            )
+
         if callback:
             callback(parameters, result)
 
@@ -213,7 +229,7 @@ def torch_minimize(
 
     # Account for final parameter update when stopping due to step_limit
     return OptimizationResult(
-        step=step + 1,
+        step=step,
         fval=closure()[0].detach().cpu().item(),
         status=OptimizationStatus.STOPPED,
         runtime=monotonic() - start_time,
diff --git a/botorch/optim/fit.py b/botorch/optim/fit.py
index 8f507a75bb..c9925bc74e 100644
--- a/botorch/optim/fit.py
+++ b/botorch/optim/fit.py
@@ -79,6 +79,7 @@ def fit_gpytorch_mll_scipy(
     method: str = "L-BFGS-B",
     options: Optional[Dict[str, Any]] = None,
     callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic scipy.optimize-based fitting routine for GPyTorch MLLs.
 
@@ -98,6 +99,8 @@ def fit_gpytorch_mll_scipy(
         options: Dictionary of solver options, passed along to scipy.minimize.
         callback: Optional callback taking `parameters` and an OptimizationResult
             as its sole arguments.
+        timeout_sec: Timeout in seconds after which to terminate the fitting loop
+            (note that timing out can result in bad fits!).
 
     Returns:
         The final OptimizationResult.
@@ -121,6 +124,7 @@ def fit_gpytorch_mll_scipy(
         method=method,
         options=options,
         callback=callback,
+        timeout_sec=timeout_sec,
     )
     if result.status != OptimizationStatus.SUCCESS:
         warn(
@@ -143,6 +147,7 @@ def fit_gpytorch_mll_torch(
     optimizer: Union[Optimizer, Callable[..., Optimizer]] = Adam,
     scheduler: Optional[Union[_LRScheduler, Callable[..., _LRScheduler]]] = None,
     callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
+    timeout_sec: Optional[float] = None,
 ) -> OptimizationResult:
     r"""Generic torch.optim-based fitting routine for GPyTorch MLLs.
 
@@ -164,6 +169,8 @@ def fit_gpytorch_mll_torch(
             that takes an `Optimizer` instance and returns an `_LRSchedule`.
         callback: Optional callback taking `parameters` and an OptimizationResult
             as its sole arguments.
+        timeout_sec: Timeout in seconds after which to terminate the fitting loop
+            (note that timing out can result in bad fits!).
 
     Returns:
         The final OptimizationResult.
@@ -191,6 +198,7 @@ def fit_gpytorch_mll_torch(
         step_limit=step_limit,
         stopping_criterion=stopping_criterion,
         callback=callback,
+        timeout_sec=timeout_sec,
     )
 
 
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index 55cc0c083c..3dbf3c0357 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -10,6 +10,7 @@
 
 from __future__ import annotations
 
+import time
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -130,6 +131,9 @@ def optimize_acqf(
         >>>     qEI, bounds, 3, 15, 256, sequential=True
         >>> )
     """
+    start_time: float = time.monotonic()
+    timeout_sec = kwargs.pop("timeout_sec", None)
+
     if inequality_constraints is None:
         if not (bounds.ndim == 2 and bounds.shape[0] == 2):
             raise ValueError(
@@ -175,7 +179,12 @@ def optimize_acqf(
         else gen_batch_initial_conditions
     )
 
+    # Perform sequential optimization via successive conditioning on pending points
     if sequential and q > 1:
+        if timeout_sec is not None:
+            # When using sequential optimization, we allocate the remaining
+            # timeout evenly across the individual acquisition optimizations.
+            timeout_sec = (timeout_sec - time.monotonic() + start_time) / q
         if not return_best_only:
             raise NotImplementedError(
                 "`return_best_only=False` only supported for joint optimization."
@@ -188,6 +197,7 @@ def optimize_acqf(
         candidate_list, acq_value_list = [], []
         base_X_pending = acq_function.X_pending
         for i in range(q):
+
             candidate, acq_value = optimize_acqf(
                 acq_function=acq_function,
                 bounds=bounds,
@@ -204,6 +214,7 @@ def optimize_acqf(
                 return_best_only=True,
                 sequential=False,
                 ic_generator=ic_gen,
+                timeout_sec=timeout_sec,
             )
 
             candidate_list.append(candidate)
@@ -219,6 +230,7 @@ def optimize_acqf(
         acq_function.set_X_pending(base_X_pending)
         return candidates, torch.stack(acq_value_list)
 
+    # Batch optimization (including the case q=1)
     options = options or {}
 
     # Handle the trivial case when all features are fixed
@@ -252,22 +264,27 @@ def optimize_acqf(
         "batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
     )
 
-    def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
+    def _optimize_batch_candidates(
+        timeout_sec: Optional[float],
+    ) -> Tuple[Tensor, Tensor, List[Warning]]:
         batch_candidates_list: List[Tensor] = []
         batch_acq_values_list: List[Tensor] = []
         batched_ics = batch_initial_conditions.split(batch_limit)
         opt_warnings = []
-
-        scipy_kws = dict(
-            acquisition_function=acq_function,
-            lower_bounds=None if bounds[0].isinf().all() else bounds[0],
-            upper_bounds=None if bounds[1].isinf().all() else bounds[1],
-            options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
-            inequality_constraints=inequality_constraints,
-            equality_constraints=equality_constraints,
-            nonlinear_inequality_constraints=nonlinear_inequality_constraints,
-            fixed_features=fixed_features,
-        )
+        if timeout_sec is not None:
+            timeout_sec = timeout_sec / len(batched_ics)
+
+        scipy_kws = {
+            "acquisition_function": acq_function,
+            "lower_bounds": None if bounds[0].isinf().all() else bounds[0],
+            "upper_bounds": None if bounds[1].isinf().all() else bounds[1],
+            "options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
+            "inequality_constraints": inequality_constraints,
+            "equality_constraints": equality_constraints,
+            "nonlinear_inequality_constraints": nonlinear_inequality_constraints,
+            "fixed_features": fixed_features,
+            "timeout_sec": timeout_sec,
+        }
 
         for i, batched_ics_ in enumerate(batched_ics):
             # optimize using random restart optimization
@@ -285,7 +302,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
         batch_acq_values = torch.cat(batch_acq_values_list)
         return batch_candidates, batch_acq_values, opt_warnings
 
-    batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
+    batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
 
     optimization_warning_raised = any(
         (issubclass(w.category, OptimizationWarning) for w in ws)
@@ -319,7 +336,9 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
             **kwargs,
         )
 
-        batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
+        batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
+            timeout_sec
+        )
 
         optimization_warning_raised = any(
             (issubclass(w.category, OptimizationWarning) for w in ws)
diff --git a/test/optim/test_core.py b/test/optim/test_core.py
index df55583c70..c9734ba40c 100644
--- a/test/optim/test_core.py
+++ b/test/optim/test_core.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import time
 from functools import partial
 from typing import Dict
 from unittest.mock import MagicMock, patch
@@ -46,6 +47,16 @@ def free_parameters(self) -> Dict[str, Tensor]:
         return {n: p for n, p in self.named_parameters() if p.requires_grad}
 
 
+def norm_squared(x, delay: float = 0.0):
+    if x.grad is not None:
+        x.grad.zero_()
+    loss = x.square().sum()
+    loss.backward()
+    if delay:
+        time.sleep(delay)
+    return loss, [x.grad]
+
+
 class TestScipyMinimize(BotorchTestCase):
     def setUp(self):
         super().setUp()
@@ -62,19 +73,19 @@ def setUp(self):
 
     def test_basic(self):
         x = Parameter(torch.rand([]))
-
-        def closure():
-            if x.grad is not None:
-                x.grad.zero_()
-
-            loss = x.square().sum()
-            loss.backward()
-            return loss, [x.grad]
-
+        closure = partial(norm_squared, x)
         result = scipy_minimize(closure, {"x": x})
         self.assertEqual(result.status, OptimizationStatus.SUCCESS)
         self.assertTrue(allclose(result.fval, 0.0))
 
+    def test_timeout(self):
+        x = Parameter(torch.tensor(1.0))
+        # adding a small delay here to combat some timing issues on windows
+        closure = partial(norm_squared, x, delay=1e-3)
+        result = scipy_minimize(closure, {"x": x}, timeout_sec=1e-4)
+        self.assertEqual(result.status, OptimizationStatus.STOPPED)
+        self.assertTrue("Optimization timed out after" in result.message)
+
     def test_main(self):
         def _callback(parameters, result, out) -> None:
             out.append(result)
@@ -125,12 +136,12 @@ def _callback(parameters, result, out) -> None:
     def test_post_processing(self):
         closure = next(iter(self.closures.values()))
         wrapper = NdarrayOptimizationClosure(closure, closure.parameters)
-        with patch.object(core, "minimize") as mock_minimize:
+        with patch.object(core, "minimize_with_timeout") as mock_minimize_with_timeout:
             for status, msg in (
                 (OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"),
                 (OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
             ):
-                mock_minimize.return_value = OptimizeResult(
+                mock_minimize_with_timeout.return_value = OptimizeResult(
                     x=wrapper.state,
                     fun=1.0,
                     nit=3,
@@ -139,7 +150,9 @@ def test_post_processing(self):
                 )
                 result = core.scipy_minimize(wrapper, closure.parameters)
                 self.assertEqual(result.status, status)
-                self.assertEqual(result.fval, mock_minimize.return_value.fun)
+                self.assertEqual(
+                    result.fval, mock_minimize_with_timeout.return_value.fun
+                )
                 self.assertEqual(
                     result.message, msg if isinstance(msg, str) else msg.decode("ascii")
                 )
@@ -161,19 +174,19 @@ def setUp(self):
 
     def test_basic(self):
         x = Parameter(torch.tensor([0.02]))
-
-        def closure():
-            if x.grad is not None:
-                x.grad.zero_()
-
-            loss = x.square().sum()
-            loss.backward()
-            return loss, [x.grad]
-
+        closure = partial(norm_squared, x)
         result = torch_minimize(closure, {"x": x}, step_limit=100)
         self.assertEqual(result.status, OptimizationStatus.STOPPED)
         self.assertTrue(allclose(result.fval, 0.0))
 
+    def test_timeout(self):
+        x = Parameter(torch.tensor(1.0))
+        # adding a small delay here to combat some timing issues on windows
+        closure = partial(norm_squared, x, delay=1e-3)
+        result = torch_minimize(closure, {"x": x}, timeout_sec=1e-4)
+        self.assertEqual(result.status, OptimizationStatus.STOPPED)
+        self.assertTrue("stopped due to timeout after" in result.message)
+
     def test_main(self):
         def _callback(parameters, result, out) -> None:
             out.append(result)
@@ -239,7 +252,7 @@ def _callback(parameters, result, out) -> None:
             parameters=closure.parameters,
             stopping_criterion=lambda fval: next(stopping_decisions),
         )
-        self.assertEqual(result.step, 2)
+        self.assertEqual(result.step, 3)
         self.assertEqual(result.status, OptimizationStatus.STOPPED)
 
         # Test passing `scheduler`
diff --git a/test/optim/test_fit.py b/test/optim/test_fit.py
index b4aca28520..57af8594cd 100644
--- a/test/optim/test_fit.py
+++ b/test/optim/test_fit.py
@@ -102,9 +102,9 @@ def _test_fit_gpytorch_mll_scipy(self, mll):
             mock_x.append(values.view(-1))
 
         with module_rollback_ctx(mll, checkpoint=ckpt), patch.object(
-            core, "minimize"
-        ) as mock_minimize:
-            mock_minimize.return_value = OptimizeResult(
+            core, "minimize_with_timeout"
+        ) as mock_minimize_with_timeout:
+            mock_minimize_with_timeout.return_value = OptimizeResult(
                 x=torch.concat(mock_x).tolist(),
                 success=False,
                 status=0,
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index 5ecdf8851c..24cf74b7ee 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -219,7 +219,10 @@ def test_optimize_acqf_joint(
     @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
     @mock.patch("botorch.optim.optimize.gen_candidates_scipy")
     def test_optimize_acqf_sequential(
-        self, mock_gen_candidates_scipy, mock_gen_batch_initial_conditions
+        self,
+        mock_gen_candidates_scipy,
+        mock_gen_batch_initial_conditions,
+        timeout_sec=None,
     ):
         q = 3
         num_restarts = 2
@@ -261,6 +264,7 @@ def test_optimize_acqf_sequential(
             inequality_constraints=inequality_constraints,
             post_processing_func=rounding_func,
             sequential=True,
+            timeout_sec=timeout_sec,
         )
         self.assertTrue(torch.equal(candidates, expected_candidates))
         self.assertTrue(
@@ -312,6 +316,9 @@ def test_optimize_acqf_sequential(
                 sequential=True,
             )
 
+    def test_optimize_acqf_sequential_timeout(self):
+        self.test_optimize_acqf_sequential(timeout_sec=1e-4)
+
     def test_optimize_acqf_sequential_notimplemented(self):
         # Sequential acquisition function optimization only supported
         # when return_best_only=True
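
For reference (not part of the patch): a minimal usage sketch of the new `timeout_sec` behavior, reusing the toy `norm_squared` closure from test/optim/test_core.py above. The delay and budget values are illustrative only; the expected statuses mirror the assertions in the tests.

import time

import torch
from torch.nn import Parameter

from botorch.optim.core import OptimizationStatus, scipy_minimize, torch_minimize


def norm_squared(x, delay: float = 0.0):
    # Toy closure returning (loss, [gradient]), as in the tests above.
    if x.grad is not None:
        x.grad.zero_()
    loss = x.square().sum()
    loss.backward()
    if delay:
        time.sleep(delay)
    return loss, [x.grad]


# A generous budget leaves the behavior unchanged: the run converges as before.
x = Parameter(torch.tensor(1.0))
result = scipy_minimize(lambda: norm_squared(x), {"x": x}, timeout_sec=10.0)
print(result.status)  # expected: OptimizationStatus.SUCCESS

# A tiny budget (with a small artificial delay per step) stops the run early and
# reports a timeout message in the result instead of raising.
x = Parameter(torch.tensor(1.0))
result = torch_minimize(
    lambda: norm_squared(x, delay=1e-3), {"x": x}, timeout_sec=1e-4
)
print(result.status, result.message)  # expected: OptimizationStatus.STOPPED + message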