From 538d443bc7d43063ecaaa09f1f04106e8945d6c2 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Tue, 14 Feb 2023 11:47:08 -0800
Subject: [PATCH] Continuing cleanup of optimize_acqf (#1676)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1676

The usage of `kwargs` in `optimize_acqf` and related functions has a few
downsides:
- It's not transparent to the user what options are supported
- When kwargs get mutated by `.pop()`, this can cause subtle errors
- If a user provides an unsupported option, there will be no error or warning.

In this diff I change some options in `kwargs` to named arguments. I did not
entirely remove `kwargs`, because they serve a few purposes, all of which are
kind of iffy. I'll attack that in the next diff in the stack. These purposes
are:
- Some kwargs, now marked as `ic_gen_kwargs`, get passed from `optimize_acqf`
  to a function that generates ICs. This will only happen if the user provides
  a function with a different signature from the BoTorch IC generator
  functions. I'm not sure anyone was actually using that functionality. Tests
  still pass if the `kwargs` argument in `optimize_acqf` is removed.
- Users may pass incorrect keyword arguments for no good reason. This fails
  silently.
- Ax `botorch_modular` passes the same signature to a variety of optimizers
  when the user does not specify otherwise. So it passes keyword arguments
  that don't do anything, but for a good reason. I think it would make sense
  to have BoTorch raise a warning and Ax silence it.

Differential Revision: D43200823

fbshipit-source-id: c7679451da10b5934bf3c55e4909d418cfa72a70
---
 botorch/optim/initializers.py |  18 ++++-
 botorch/optim/optimize.py     | 125 ++++++++++++++++++++++-------
 2 files changed, 99 insertions(+), 44 deletions(-)

diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index 906bcf7b21..8b392019c9 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -16,7 +16,7 @@
 import warnings
 from math import ceil
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from botorch import settings
@@ -46,6 +46,22 @@
 from torch.distributions import Normal
 from torch.quasirandom import SobolEngine

+TGenInitialConditions = Callable[
+    [
+        # reasoning behind this annotation: contravariance
+        qKnowledgeGradient,
+        Tensor,
+        int,
+        int,
+        int,
+        Optional[Dict[int, float]],
+        Optional[Dict[str, Union[bool, float, int]]],
+        Optional[List[Tuple[Tensor, Tensor, float]]],
+        Optional[List[Tuple[Tensor, Tensor, float]]],
+    ],
+    Optional[Tensor],
+]
+

 def gen_batch_initial_conditions(
     acq_function: AcquisitionFunction,
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index d91eb06d34..ae745be1e6 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -30,6 +30,7 @@
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
 from botorch.optim.utils import _filter_kwargs
@@ -53,7 +54,7 @@
 }


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class OptimizeAcqfInputs:
     """
     Container for inputs to `optimize_acqf`.
@@ -76,10 +77,19 @@ class OptimizeAcqfInputs:
     return_best_only: bool
     gen_candidates: TGenCandidates
     sequential: bool
-    kwargs: Dict[str, Any]
-    ic_generator: Callable = dataclasses.field(init=False)
+    ic_generator: Optional[TGenInitialConditions] = None
+    timeout_sec: Optional[float] = None
+    return_full_tree: bool = False
+    ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)
+
+    @property
+    def full_tree(self) -> bool:
+        return (
+            isinstance(self.acq_function, OneShotAcquisitionFunction)
+            and not self.return_full_tree
+        )

-    def _validate(self) -> None:
+    def __post_init__(self) -> None:
         if self.inequality_constraints is None and not (
             self.bounds.ndim == 2 and self.bounds.shape[0] == 2
         ):
@@ -114,7 +124,7 @@ def _validate(self) -> None:
                 f"shape is {batch_initial_conditions_shape}."
             )

-        elif "ic_generator" not in self.kwargs.keys():
+        elif not self.ic_generator:
             if self.nonlinear_inequality_constraints:
                 raise RuntimeError(
                     "`ic_generator` must be given if "
@@ -137,14 +147,13 @@ def _validate(self) -> None:
                 "acquisition functions. Must have `sequential=False`."
             )

-    def __post_init__(self) -> None:
-        self._validate()
-        if "ic_generator" in self.kwargs.keys():
-            self.ic_generator = self.kwargs.pop("ic_generator")
+    @property
+    def ic_gen(self) -> TGenInitialConditions:
+        if self.ic_generator:
+            return self.ic_generator
         elif isinstance(self.acq_function, qKnowledgeGradient):
-            self.ic_generator = gen_one_shot_kg_initial_conditions
-        else:
-            self.ic_generator = gen_batch_initial_conditions
+            return gen_one_shot_kg_initial_conditions
+        return gen_batch_initial_conditions


 def _optimize_acqf_all_features_fixed(
@@ -170,34 +179,32 @@ def _optimize_acqf_all_features_fixed(

 def _optimize_acqf_sequential_q(
-    opt_inputs: OptimizeAcqfInputs,
-    timeout_sec: Optional[float],
-    start_time: float,
+    opt_inputs: OptimizeAcqfInputs, timeout_sec: Optional[float], start_time: float
 ) -> Tuple[Tensor, Tensor]:
     """
     Helper function for `optimize_acqf` when sequential=True and q > 1.
     """
-    kwargs = opt_inputs.kwargs or {}
     if timeout_sec is not None:
         # When using sequential optimization, we allocate the total timeout
         # evenly across the individual acquisition optimizations.
        timeout_sec = (timeout_sec - start_time) / opt_inputs.q
-        kwargs["timeout_sec"] = timeout_sec

     candidate_list, acq_value_list = [], []
     base_X_pending = opt_inputs.acq_function.X_pending
+    new_inputs = dataclasses.replace(
+        opt_inputs,
+        q=1,
+        batch_initial_conditions=None,
+        return_best_only=True,
+        sequential=False,
+        timeout_sec=timeout_sec,
+    )
     for i in range(opt_inputs.q):
-        kwargs["ic_generator"] = opt_inputs.ic_generator
-        new_inputs = dataclasses.replace(
-            opt_inputs,
-            q=1,
-            batch_initial_conditions=None,
-            return_best_only=True,
-            sequential=False,
-            kwargs=kwargs,
+
+        candidate, acq_value = _optimize_acqf_batch(
+            new_inputs, start_time=start_time, timeout_sec=timeout_sec
         )
-        candidate, acq_value = _optimize_acqf(new_inputs)

         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
@@ -217,17 +224,13 @@ def _optimize_acqf_batch(
 ) -> Tuple[Tensor, Tensor]:
     options = opt_inputs.options or {}

-    kwargs = opt_inputs.kwargs
-    full_tree = isinstance(
-        opt_inputs.acq_function, OneShotAcquisitionFunction
-    ) and not kwargs.pop("return_full_tree", False)
-
     initial_conditions_provided = opt_inputs.batch_initial_conditions is not None

     if initial_conditions_provided:
         batch_initial_conditions = opt_inputs.batch_initial_conditions
     else:
-        batch_initial_conditions = opt_inputs.ic_generator(
+        # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
+        batch_initial_conditions = opt_inputs.ic_gen(
             acq_function=opt_inputs.acq_function,
             bounds=opt_inputs.bounds,
             q=opt_inputs.q,
@@ -237,7 +240,7 @@
             options=options,
             inequality_constraints=opt_inputs.inequality_constraints,
             equality_constraints=opt_inputs.equality_constraints,
-            **kwargs,
+            **opt_inputs.ic_gen_kwargs,
         )

     batch_limit: int = options.get(
@@ -330,7 +333,7 @@ def _optimize_batch_candidates(
             warnings.warn(first_warn_msg, RuntimeWarning)

             if not initial_conditions_provided:
-                batch_initial_conditions = opt_inputs.ic_generator(
+                batch_initial_conditions = opt_inputs.ic_gen(
                     acq_function=opt_inputs.acq_function,
                     bounds=opt_inputs.bounds,
                     q=opt_inputs.q,
@@ -340,7 +343,7 @@
                     options=options,
                     inequality_constraints=opt_inputs.inequality_constraints,
                     equality_constraints=opt_inputs.equality_constraints,
-                    **kwargs,
+                    **opt_inputs.ic_gen_kwargs,
                 )

                 batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
@@ -365,7 +368,7 @@
         batch_candidates = batch_candidates[best]
         batch_acq_values = batch_acq_values[best]

-    if full_tree:
+    if opt_inputs.full_tree:
         batch_candidates = opt_inputs.acq_function.extract_candidates(
             X_full=batch_candidates
         )
@@ -389,7 +392,11 @@
     return_best_only: bool = True,
     gen_candidates: Optional[TGenCandidates] = None,
     sequential: bool = False,
-    **kwargs: Any,
+    *,
+    ic_generator: Optional[TGenInitialConditions] = None,
+    timeout_sec: Optional[float] = None,
+    return_full_tree: bool = False,
+    **ic_gen_kwargs: Any,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of candidates via multi-start optimization.

@@ -435,7 +442,15 @@
         for method-specific inputs. Default: `gen_candidates_scipy`
         sequential: If False, uses joint optimization, otherwise uses sequential
             optimization.
-        kwargs: Additonal keyword arguments.
+        ic_generator: Function for generating initial conditions. Not needed when
+            `batch_initial_conditions` are provided.
+            Defaults to `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient`
+            acquisition functions and `gen_batch_initial_conditions` otherwise. Must
+            be specified for nonlinear inequality constraints.
+        timeout_sec: Max amount of time (in seconds) that optimization can run for.
+        return_full_tree: If True, return the full tree of solutions of a one-shot acquisition function rather than only the extracted final candidates.
+        ic_gen_kwargs: Additional keyword arguments passed to the function specified
+            by `ic_generator`.

     Returns:
         A two-element tuple containing
@@ -481,7 +496,10 @@
         return_best_only=return_best_only,
         gen_candidates=gen_candidates,
         sequential=sequential,
-        kwargs=kwargs,
+        ic_generator=ic_generator,
+        timeout_sec=timeout_sec,
+        return_full_tree=return_full_tree,
+        ic_gen_kwargs=ic_gen_kwargs,
     )
     return _optimize_acqf(opt_acqf_inputs)

@@ -501,8 +519,7 @@
     )

     start_time: float = time.monotonic()
-    kwargs = opt_inputs.kwargs
-    timeout_sec = kwargs.pop("timeout_sec", None)
+    timeout_sec = opt_inputs.timeout_sec

     # Perform sequential optimization via successive conditioning on pending points
     if opt_inputs.sequential and opt_inputs.q > 1:
@@ -531,7 +548,11 @@
     post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
     batch_initial_conditions: Optional[Tensor] = None,
     cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
-    **kwargs,
+    *,
+    ic_generator: Optional[TGenInitialConditions] = None,
+    timeout_sec: Optional[float] = None,
+    return_full_tree: bool = False,
+    **ic_gen_kwargs: Any,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of `q` candidates via cyclic optimization.

@@ -561,6 +582,15 @@
             If no initial conditions are provided, the default
             initialization will be used.
         cyclic_options: Options for stopping criterion for outer cyclic optimization.
+        ic_generator: Function for generating initial conditions. Not needed when
+            `batch_initial_conditions` are provided. Defaults to
+            `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
+            functions and `gen_batch_initial_conditions` otherwise. Must be specified
+            for nonlinear inequality constraints.
+        timeout_sec: Max amount of time (in seconds) that optimization can run for.
+        return_full_tree: If True, return the full tree of solutions of a one-shot acquisition function rather than only the extracted final candidates.
+        ic_gen_kwargs: Additional keyword arguments passed to the function specified
+            by `ic_generator`.

     Returns:
         A two-element tuple containing
@@ -596,7 +626,10 @@
         return_best_only=True,
         gen_candidates=gen_candidates_scipy,
         sequential=True,
-        kwargs=kwargs,
+        ic_generator=ic_generator,
+        timeout_sec=timeout_sec,
+        return_full_tree=return_full_tree,
+        ic_gen_kwargs=ic_gen_kwargs,
     )

     # for the first cycle, optimize the q candidates sequentially
@@ -778,6 +811,8 @@
             transformations).
         batch_initial_conditions: A tensor to specify the initial conditions. Set
             this if you do not want to use default initialization strategy.
+        kwargs: kwargs do nothing. They are accepted only so that the same arguments
+            can be passed to different optimizers without raising an error.

     Returns:
         A two-element tuple containing
@@ -881,6 +916,8 @@
             a large training set.
         unique: If True return unique choices, o/w choices may be repeated (only
            relevant if `q > 1`).
+        kwargs: kwargs do nothing. They are accepted only so that the same arguments
+            can be passed to different optimizers without raising an error.

     Returns:
         A three-element tuple containing
@@ -1045,6 +1082,8 @@
             a large training set.
        unique: If True return unique choices, o/w choices may be repeated (only
            relevant if `q > 1`).
+        kwargs: kwargs do nothing. They are accepted only so that the same arguments
+            can be passed to different optimizers without raising an error.

     Returns:
         A two-element tuple containing
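
For illustration, a minimal sketch of the new call pattern (not part of the
patch; the toy model, data, bounds, and acquisition function below are
hypothetical stand-ins, assuming a BoTorch build that includes this change):

    import torch
    from botorch.acquisition import qExpectedImprovement
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from botorch.optim import optimize_acqf
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # Toy problem: 10 random points in [0, 1]^2 with a synthetic objective.
    train_X = torch.rand(10, 2, dtype=torch.double)
    train_Y = (train_X**2).sum(dim=-1, keepdim=True)

    # Fit a GP surrogate and build a batch acquisition function.
    model = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
    acqf = qExpectedImprovement(model=model, best_f=train_Y.max())
    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

    # `timeout_sec` and `return_full_tree` are now explicit keyword-only
    # arguments instead of entries popped out of a mutable `kwargs` dict;
    # any remaining keyword arguments are forwarded to the initial-condition
    # generator as `ic_gen_kwargs`.
    candidates, acq_value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=2,
        num_restarts=4,
        raw_samples=32,
        timeout_sec=30.0,
        return_full_tree=False,  # only meaningful for one-shot acqfs, e.g. qKG
    )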