
Expose timeout option in higher-level optimization wrappers (#1598)
Summary:
Pull Request resolved: #1598

Now that we can time out the optimization in `scipy.optimize.minimize` at the lower level, we can also expose this option in the higher-level optimization wrappers.
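
For illustration, here is a minimal end-to-end sketch of the exposed option; the `SingleTaskGP` / `qExpectedImprovement` setup and the specific argument values are assumptions for the example, not part of this diff.

```python
# Illustrative sketch only: the model/acquisition setup is assumed, not from this PR.
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_mll_scipy
from botorch.optim.optimize import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Abort model fitting (with a warning) if L-BFGS-B runs longer than 10 seconds.
fit_gpytorch_mll_scipy(mll, timeout_sec=10.0)

# Give the acquisition optimization a total wall-clock budget of 30 seconds.
acqf = qExpectedImprovement(model, best_f=train_Y.max())
bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(train_X)
candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=4,
    raw_samples=64,
    timeout_sec=30.0,
)
```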

Differential Revision: D42254406

fbshipit-source-id: 09680b257172280f991460933c6e76090622b320
Balandat authored and facebook-github-bot committed Dec 28, 2022
1 parent 1cfd1d6 commit 6bd5ca4
Showing 5 changed files with 111 additions and 43 deletions.
26 changes: 21 additions & 5 deletions botorch/fit.py
@@ -48,10 +48,26 @@
from torch.utils.data import DataLoader


def _debug_warn(w: WarningMessage) -> bool:
if _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)):
return True
# TODO: Better handle cases where warning handling logic
# affects both debug and rethrow functions.
return False


def _rethrow_warn(w: WarningMessage) -> bool:
if not issubclass(w.category, OptimizationWarning):
return True
if "Optimization timed out after" in str(w.message):
return True
return False


DEFAULT_WARNING_HANDLER = partial(
_warning_handler_template,
debug=lambda w: _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)),
rethrow=lambda w: not issubclass(w.category, OptimizationWarning),
debug=_debug_warn,
rethrow=_rethrow_warn,
)
FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder)

@@ -188,9 +204,9 @@ def _fit_fallback(
optimizer_kwargs: Keyword arguments passed to `optimizer`.
max_attempts: The maximum number of fit attempts allowed. The attempt budget
is NOT shared between calls to this method.
warning_filter: A function used to filter warnings produced when calling
`optimizer`. Any unfiltered warnings will be rethrown and trigger a
model fitting retry.
warning_handler: A function used to filter warnings produced when calling
`optimizer`. Any unfiltered warnings (those for which `warning_handler`
returns `False`) will be rethrown and trigger a model fitting retry.
caught_exception_types: A tuple of exception types whose instances should
be redirected to `logging.DEBUG`.
**ignore: This function ignores unrecognized keyword arguments.
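
For context on the `warning_handler` contract documented above, here is a minimal sketch of building a custom handler the same way as `DEFAULT_WARNING_HANDLER`; the import path and keyword signature of `_warning_handler_template`, and the `warning_handler=` usage shown in the final comment, are assumptions rather than part of this diff.

```python
# Sketch of a custom warning handler; import path and template signature assumed.
from functools import partial
from warnings import WarningMessage

from botorch.exceptions import OptimizationWarning
from botorch.optim.utils import _warning_handler_template


def _debug_timeouts(w: WarningMessage) -> bool:
    # Route timeout-related warnings to logging.DEBUG.
    return "timed out" in str(w.message).lower()


def _rethrow_non_optimization(w: WarningMessage) -> bool:
    # Surface anything that is not an OptimizationWarning to the caller.
    return not issubclass(w.category, OptimizationWarning)


my_warning_handler = partial(
    _warning_handler_template,
    debug=_debug_timeouts,
    rethrow=_rethrow_non_optimization,
)
# Hypothetical usage: fit_gpytorch_mll(mll, warning_handler=my_warning_handler)
```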
48 changes: 33 additions & 15 deletions botorch/optim/core.py
@@ -17,9 +17,9 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from botorch.optim.closures import NdarrayOptimizationClosure
from botorch.optim.utils import get_bounds_as_ndarray
from botorch.optim.utils.numpy_utils import get_bounds_as_ndarray
from botorch.optim.utils.timeout import minimize_with_timeout
from numpy import asarray, float64 as np_float64, ndarray
from scipy.optimize import minimize
from torch import Tensor
from torch.optim.adam import Adam
from torch.optim.optimizer import Optimizer
@@ -29,6 +29,11 @@
except ImportError: # pragma: no cover
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # pragma: no cover

TOptimizationClosure = Union[
Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]],
NdarrayOptimizationClosure,
]


_LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages
"TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
@@ -52,16 +57,14 @@ class OptimizationResult:


def scipy_minimize(
closure: Union[
Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]],
NdarrayOptimizationClosure,
],
closure: TOptimizationClosure,
parameters: Dict[str, Tensor],
bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None,
callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
x0: Optional[ndarray] = None,
method: str = "L-BFGS-B",
options: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[float] = None,
) -> OptimizationResult:
r"""Generic scipy.optimize.minimize-based optimization routine.
@@ -74,6 +77,8 @@ def scipy_minimize(
x0: An optional initialization vector passed to scipy.optimize.minimize.
method: Solver type, passed along to scipy.minimize.
options: Dictionary of solver options, passed along to scipy.minimize.
timeout_sec: Timeout in seconds to wait before aborting the optimization loop
if not converged (will return the best found solution thus far).
Returns:
An OptimizationResult summarizing the final state of the run.
@@ -103,14 +108,15 @@ def wrapped_callback(x: ndarray):
)
return callback(parameters, result) # pyre-ignore [29]

raw = minimize(
raw = minimize_with_timeout(
wrapped_closure,
wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
jac=True,
bounds=bounds_np,
method=method,
options=options,
callback=wrapped_callback,
timeout_sec=timeout_sec,
)

# Post-processing and outcome handling
@@ -122,6 +128,7 @@ def wrapped_callback(x: ndarray):
status = ( # Check whether we stopped due to reaching maxfun or maxiter
OptimizationStatus.STOPPED
if _LBFGSB_MAXITER_MAXFUN_REGEX.search(msg)
or "Optimization timed out after" in msg
else OptimizationStatus.FAILURE
)
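
A minimal sketch of the new `timeout_sec` path in `scipy_minimize`, assuming a toy quadratic closure and that plain closures are wrapped in `NdarrayOptimizationClosure` as the signature suggests:

```python
# Toy example; the closure and parameters are illustrative only.
import torch
from botorch.optim.core import OptimizationStatus, scipy_minimize

x = torch.tensor([1.5, -0.5], requires_grad=True)
parameters = {"x": x}


def closure():
    # Return (loss, gradients) as expected by the closure protocol.
    if x.grad is not None:
        x.grad.zero_()
    loss = (x**2).sum()
    loss.backward()
    return loss, [x.grad]


result = scipy_minimize(closure, parameters, timeout_sec=5.0)
if result.status == OptimizationStatus.STOPPED:
    # The message reports hitting maxiter/maxfun or the timeout.
    print(result.message)
```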

@@ -142,6 +149,7 @@ def torch_minimize(
optimizer: Union[Optimizer, Callable[[List[Tensor]], Optimizer]] = Adam,
scheduler: Optional[Union[LRScheduler, Callable[[Optimizer], LRScheduler]]] = None,
step_limit: Optional[int] = None,
timeout_sec: Optional[float] = None,
stopping_criterion: Optional[Callable[[Tensor], bool]] = None,
) -> OptimizationResult:
r"""Generic torch.optim-based optimization routine.
@@ -152,19 +160,23 @@
parameters: A dictionary of tensors to be optimized.
bounds: An optional dictionary of bounds for elements of `parameters`.
callback: A callable taking `parameters` and an OptimizationResult as arguments.
step_limit: Integer specifying a maximum number of optimization steps.
One of `step_limit` or `stopping_criterion` must be passed.
stopping_criterion: A StoppingCriterion for the optimization loop.
optimizer: A `torch.optim.Optimizer` instance or a factory that takes
a list of parameters and returns an `Optimizer` instance.
scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory
that takes an `Optimizer` instance and returns an `_LRScheduler` instance.
step_limit: Integer specifying a maximum number of optimization steps.
One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
timeout_sec: Timeout in seconds before terminating the optimization loop.
One of `step_limit`, `stopping_criterion`, or `timeout_sec` must be passed.
stopping_criterion: A StoppingCriterion for the optimization loop.
Returns:
An OptimizationResult summarizing the final state of the run.
"""
result: OptimizationResult
start_time = monotonic()
if step_limit is None:

if step_limit is None and timeout_sec is None:
if stopping_criterion is None:
raise RuntimeError("No termination conditions were given.")
step_limit = maxsize
@@ -180,21 +192,27 @@
if bounds is None
else {name: limits for name, limits in bounds.items() if name in parameters}
)
result: OptimizationResult
for step in range(step_limit):
for step in range(1, step_limit + 1):
fval, _ = closure()
runtime = monotonic() - start_time
result = OptimizationResult(
step=step,
fval=fval.detach().cpu().item(),
status=OptimizationStatus.RUNNING,
runtime=monotonic() - start_time,
runtime=runtime,
)

# TODO: Update stopping_criterion API to return a message.
if stopping_criterion and stopping_criterion(fval):
result.status = OptimizationStatus.STOPPED
result.message = "`torch_minimize` stopped due to `stopping_criterion`."

if timeout_sec is not None and runtime >= timeout_sec:
result.status = OptimizationStatus.STOPPED
result.message = (
f"`torch_minimize` stopped due to timeout after {runtime} seconds."
)

if callback:
callback(parameters, result)

@@ -213,7 +231,7 @@

# Account for final parameter update when stopping due to step_limit
return OptimizationResult(
step=step + 1,
step=step,
fval=closure()[0].detach().cpu().item(),
status=OptimizationStatus.STOPPED,
runtime=monotonic() - start_time,
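
A matching sketch for `torch_minimize`; the toy closure is illustrative, and `step_limit` is passed alongside `timeout_sec` so that whichever limit is reached first stops the loop:

```python
# Toy example; closure and parameters are illustrative only.
import torch
from botorch.optim.core import torch_minimize

x = torch.tensor([1.5, -0.5], requires_grad=True)
parameters = {"x": x}


def closure():
    if x.grad is not None:
        x.grad.zero_()
    loss = (x - 2.0).pow(2).sum()
    loss.backward()
    return loss, [x.grad]


# Adam-based loop that stops after 1000 steps or 2 seconds, whichever comes first.
result = torch_minimize(closure, parameters, step_limit=1000, timeout_sec=2.0)
print(result.step, result.status, result.message)
```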
8 changes: 8 additions & 0 deletions botorch/optim/fit.py
@@ -79,6 +79,7 @@ def fit_gpytorch_mll_scipy(
method: str = "L-BFGS-B",
options: Optional[Dict[str, Any]] = None,
callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
timeout_sec: Optional[float] = None,
) -> OptimizationResult:
r"""Generic scipy.optimized-based fitting routine for GPyTorch MLLs.
@@ -98,6 +99,8 @@
options: Dictionary of solver options, passed along to scipy.minimize.
callback: Optional callback taking `parameters` and an OptimizationResult as its
sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
(note that timing out can result in bad fits!).
Returns:
The final OptimizationResult.
@@ -121,6 +124,7 @@
method=method,
options=options,
callback=callback,
timeout_sec=timeout_sec,
)
if result.status != OptimizationStatus.SUCCESS:
warn(
@@ -143,6 +147,7 @@ def fit_gpytorch_mll_torch(
optimizer: Union[Optimizer, Callable[..., Optimizer]] = Adam,
scheduler: Optional[Union[_LRScheduler, Callable[..., _LRScheduler]]] = None,
callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None,
timeout_sec: Optional[float] = None,
) -> OptimizationResult:
r"""Generic torch.optim-based fitting routine for GPyTorch MLLs.
@@ -164,6 +169,8 @@
that takes an `Optimizer` instance and returns an `_LRScheduler`.
callback: Optional callback taking `parameters` and an OptimizationResult as its
sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
(note that timing out can result in bad fits!).
Returns:
The final OptimizationResult.
@@ -191,6 +198,7 @@
step_limit=step_limit,
stopping_criterion=stopping_criterion,
callback=callback,
timeout_sec=timeout_sec,
)


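
A sketch of the torch-based fitting entry point with the same option; the `SingleTaskGP` setup is assumed for illustration and is not part of this diff:

```python
# Illustrative sketch; the GP setup is assumed, only timeout_sec comes from this PR.
import torch
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_mll_torch
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(32, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Adam-based fitting capped at 2000 steps or 15 seconds of wall-clock time.
result = fit_gpytorch_mll_torch(mll, step_limit=2000, timeout_sec=15.0)
```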
47 changes: 33 additions & 14 deletions botorch/optim/optimize.py
@@ -10,6 +10,7 @@

from __future__ import annotations

import time
import warnings

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -130,6 +131,9 @@ def optimize_acqf(
>>> qEI, bounds, 3, 15, 256, sequential=True
>>> )
"""
start_time: float = time.monotonic()
timeout_sec = kwargs.get("timeout_sec")

if inequality_constraints is None:
if not (bounds.ndim == 2 and bounds.shape[0] == 2):
raise ValueError(
@@ -175,7 +179,12 @@
else gen_batch_initial_conditions
)

# Perform sequential optimization via successive conditioning on pending points
if sequential and q > 1:
if timeout_sec is not None:
# When using sequential optimization, we allocate the total timeout
# evenly across the individual acquisition optimizations.
timeout_sec = (timeout_sec - start_time) / q
if not return_best_only:
raise NotImplementedError(
"`return_best_only=False` only supported for joint optimization."
@@ -188,6 +197,7 @@
candidate_list, acq_value_list = [], []
base_X_pending = acq_function.X_pending
for i in range(q):

candidate, acq_value = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
Expand All @@ -204,6 +214,7 @@ def optimize_acqf(
return_best_only=True,
sequential=False,
ic_generator=ic_gen,
timeout_sec=timeout_sec,
)

candidate_list.append(candidate)
@@ -219,6 +230,7 @@
acq_function.set_X_pending(base_X_pending)
return candidates, torch.stack(acq_value_list)

# Batch optimization (including the case q=1)
options = options or {}

# Handle the trivial case when all features are fixed
@@ -252,22 +264,27 @@
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
)

def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
def _optimize_batch_candidates(
timeout_sec: Optional[float],
) -> Tuple[Tensor, Tensor, List[Warning]]:
batch_candidates_list: List[Tensor] = []
batch_acq_values_list: List[Tensor] = []
batched_ics = batch_initial_conditions.split(batch_limit)
opt_warnings = []

scipy_kws = dict(
acquisition_function=acq_function,
lower_bounds=None if bounds[0].isinf().all() else bounds[0],
upper_bounds=None if bounds[1].isinf().all() else bounds[1],
options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
fixed_features=fixed_features,
)
if timeout_sec is not None:
timeout_sec = (timeout_sec - start_time) / len(batched_ics)

scipy_kws = {
"acquisition_function": acq_function,
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"timeout_sec": timeout_sec,
}

for i, batched_ics_ in enumerate(batched_ics):
# optimize using random restart optimization
@@ -285,7 +302,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
batch_acq_values = torch.cat(batch_acq_values_list)
return batch_candidates, batch_acq_values, opt_warnings

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)

optimization_warning_raised = any(
(issubclass(w.category, OptimizationWarning) for w in ws)
@@ -319,7 +336,9 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
**kwargs,
)

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
timeout_sec
)

optimization_warning_raised = any(
(issubclass(w.category, OptimizationWarning) for w in ws)
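
Finally, a sketch of the sequential case handled above, where the total `timeout_sec` budget is split across the `q` sub-optimizations; the model and acquisition setup are assumed for illustration:

```python
# Illustrative sketch; model/acquisition setup assumed, only timeout_sec is new here.
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.optimize import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(16, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qExpectedImprovement(model, best_f=train_Y.max())
bounds = torch.stack(
    [torch.zeros(2, dtype=torch.double), torch.ones(2, dtype=torch.double)]
)

candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=3,
    num_restarts=8,
    raw_samples=128,
    sequential=True,
    timeout_sec=60.0,  # total budget, split across the q=3 sub-optimizations
)
```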