Continuing cleanup of optimize_acqf (pytorch#1676)
Summary:
Pull Request resolved: pytorch#1676

The usage of `kwargs` in `optimize_acqf` and related functions has a few downsides:
- It's not transparent to the user what options are supported.
- Mutating `kwargs` via `.pop()` can cause subtle errors.
- If a user provides an unsupported option, there will be no error or warning.

In this diff I change some options in `kwargs` to named arguments (see the usage sketch after the list below). I did not entirely remove `kwargs`, because they serve a few purposes, all of which are kind of iffy; I'll attack that in the next diff in the stack. These purposes are:

- Some kwargs, now marked as `ic_gen_kwargs`, get passed from `optimize_acqf` to a function that generates ICs. This will only happen if the user provides a function with a different signature from the BoTorch IC generator functions. I'm not sure anyone was actually using that functionality; tests still pass if the `kwargs` argument in `optimize_acqf` is removed.
- Users may pass incorrect keyword arguments for no good reason. This fails silently.
- Ax `botorch_modular` passes the same signature to a variety of optimizers when the user does not specify otherwise. So it passes keyword arguments that don't do anything, but for a good reason. I think it would make sense to have BoTorch raise a warning and Ax silence it.
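
As a hedged usage sketch of the new surface: the model and acquisition setup below are placeholders for illustration, not part of this diff; only the keyword-only arguments of `optimize_acqf` reflect this change, and anything not named lands in `ic_gen_kwargs` and is forwarded to the IC generator.

# Usage sketch: options that used to travel through **kwargs are now
# explicit keyword-only arguments. Model/acqf setup here is illustrative.
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = -(train_X - 0.5).pow(2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

candidates, acq_value = optimize_acqf(
    acq_function=qExpectedImprovement(model=model, best_f=train_Y.max()),
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=10,
    raw_samples=128,
    sequential=True,
    timeout_sec=30.0,        # named argument now, not a kwargs entry
    return_full_tree=False,  # ditto; stored on OptimizeAcqfInputs
)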

Differential Revision: https://internalfb.com/D43200823

fbshipit-source-id: b66ac13dba40ccd9e0b2c4e3ef10475d9b3a8167
esantorella authored and facebook-github-bot committed Feb 14, 2023
1 parent ad38736 commit b1b20e4
Showing 2 changed files with 99 additions and 44 deletions.
18 changes: 17 additions & 1 deletion botorch/optim/initializers.py
@@ -16,7 +16,7 @@
 
 import warnings
 from math import ceil
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from botorch import settings
@@ -46,6 +46,22 @@
 from torch.distributions import Normal
 from torch.quasirandom import SobolEngine
 
+TGenInitialConditions = Callable[
+    [
+        # reasoning behind this annotation: contravariance
+        qKnowledgeGradient,
+        Tensor,
+        int,
+        int,
+        int,
+        Optional[Dict[int, float]],
+        Optional[Dict[str, Union[bool, float, int]]],
+        Optional[List[Tuple[Tensor, Tensor, float]]],
+        Optional[List[Tuple[Tensor, Tensor, float]]],
+    ],
+    Optional[Tensor],
+]
+
 
 def gen_batch_initial_conditions(
     acq_function: AcquisitionFunction,
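
For orientation, a minimal sketch of a custom generator conforming to the new `TGenInitialConditions` alias. The function name and its uniform-sampling strategy are illustrative only; the parameter names mirror the keyword arguments that `optimize_acqf` passes to `opt_inputs.ic_gen` in the diff below. Annotating the first parameter as `qKnowledgeGradient` is the contravariance point in the inline comment: since `Callable` is contravariant in its arguments, a generator accepting any `AcquisitionFunction` also satisfies the alias.

# Sketch only: a custom initial-condition generator compatible with
# TGenInitialConditions; uniform sampling stands in for a real strategy.
from typing import Dict, List, Optional, Tuple, Union

import torch
from botorch.acquisition import qKnowledgeGradient
from torch import Tensor


def uniform_initial_conditions(
    acq_function: qKnowledgeGradient,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int]]] = None,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
) -> Optional[Tensor]:
    # Uniform draws inside the box bounds; shape: num_restarts x q x d.
    lower, upper = bounds[0], bounds[1]
    d = bounds.shape[-1]
    return lower + (upper - lower) * torch.rand(
        num_restarts, q, d, dtype=bounds.dtype, device=bounds.device
    )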
125 changes: 82 additions & 43 deletions botorch/optim/optimize.py
@@ -30,6 +30,7 @@
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
 from botorch.optim.utils import _filter_kwargs
@@ -53,7 +54,7 @@
 }
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class OptimizeAcqfInputs:
     """
     Container for inputs to `optimize_acqf`.
@@ -76,10 +77,19 @@ class OptimizeAcqfInputs:
     return_best_only: bool
     gen_candidates: TGenCandidates
     sequential: bool
-    kwargs: Dict[str, Any]
-    ic_generator: Callable = dataclasses.field(init=False)
+    ic_generator: Optional[TGenInitialConditions] = None
+    timeout_sec: Optional[float] = None
+    return_full_tree: bool = False
+    ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)
 
-    def _validate(self) -> None:
+    @property
+    def full_tree(self) -> bool:
+        return (
+            isinstance(self.acq_function, OneShotAcquisitionFunction)
+            and not self.return_full_tree
+        )
+
+    def __post_init__(self) -> None:
         if self.inequality_constraints is None and not (
             self.bounds.ndim == 2 and self.bounds.shape[0] == 2
         ):
@@ -114,7 +124,7 @@ def _validate(self) -> None:
                     f"shape is {batch_initial_conditions_shape}."
                 )
 
-        elif "ic_generator" not in self.kwargs.keys():
+        elif not self.ic_generator:
            if self.nonlinear_inequality_constraints:
                raise RuntimeError(
                    "`ic_generator` must be given if "
@@ -137,14 +147,13 @@ def _validate(self) -> None:
                 "acquisition functions. Must have `sequential=False`."
             )
 
-    def __post_init__(self) -> None:
-        self._validate()
-        if "ic_generator" in self.kwargs.keys():
-            self.ic_generator = self.kwargs.pop("ic_generator")
+    @property
+    def ic_gen(self) -> TGenInitialConditions:
+        if self.ic_generator:
+            return self.ic_generator
         elif isinstance(self.acq_function, qKnowledgeGradient):
-            self.ic_generator = gen_one_shot_kg_initial_conditions
+            return gen_one_shot_kg_initial_conditions
         else:
-            self.ic_generator = gen_batch_initial_conditions
+            return gen_batch_initial_conditions
 
 
 def _optimize_acqf_all_features_fixed(
@@ -170,34 +179,32 @@ def _optimize_acqf_all_features_fixed(
 
 
 def _optimize_acqf_sequential_q(
-    opt_inputs: OptimizeAcqfInputs,
-    timeout_sec: Optional[float],
-    start_time: float,
+    opt_inputs: OptimizeAcqfInputs, timeout_sec: Optional[float], start_time: float
 ) -> Tuple[Tensor, Tensor]:
     """
     Helper function for `optimize_acqf` when sequential=True and q > 1.
     """
-    kwargs = opt_inputs.kwargs or {}
     if timeout_sec is not None:
         # When using sequential optimization, we allocate the total timeout
         # evenly across the individual acquisition optimizations.
         timeout_sec = (timeout_sec - start_time) / opt_inputs.q
-        kwargs["timeout_sec"] = timeout_sec
 
     candidate_list, acq_value_list = [], []
    base_X_pending = opt_inputs.acq_function.X_pending
 
+    new_inputs = dataclasses.replace(
+        opt_inputs,
+        q=1,
+        batch_initial_conditions=None,
+        return_best_only=True,
+        sequential=False,
+        timeout_sec=timeout_sec,
+    )
     for i in range(opt_inputs.q):
-        kwargs["ic_generator"] = opt_inputs.ic_generator
-        new_inputs = dataclasses.replace(
-            opt_inputs,
-            q=1,
-            batch_initial_conditions=None,
-            return_best_only=True,
-            sequential=False,
-            kwargs=kwargs,
-        )
-        candidate, acq_value = _optimize_acqf_batch(
-            new_inputs, start_time=start_time, timeout_sec=timeout_sec
-        )
+        candidate, acq_value = _optimize_acqf(new_inputs)
 
         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
@@ -217,17 +224,13 @@ def _optimize_acqf_batch(
 ) -> Tuple[Tensor, Tensor]:
     options = opt_inputs.options or {}
 
-    kwargs = opt_inputs.kwargs
-    full_tree = isinstance(
-        opt_inputs.acq_function, OneShotAcquisitionFunction
-    ) and not kwargs.pop("return_full_tree", False)
-
     initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
 
     if initial_conditions_provided:
         batch_initial_conditions = opt_inputs.batch_initial_conditions
     else:
-        batch_initial_conditions = opt_inputs.ic_generator(
+        # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
+        batch_initial_conditions = opt_inputs.ic_gen(
             acq_function=opt_inputs.acq_function,
             bounds=opt_inputs.bounds,
             q=opt_inputs.q,
@@ -237,7 +240,7 @@ def _optimize_acqf_batch(
             options=options,
             inequality_constraints=opt_inputs.inequality_constraints,
             equality_constraints=opt_inputs.equality_constraints,
-            **kwargs,
+            **opt_inputs.ic_gen_kwargs,
         )
 
     batch_limit: int = options.get(
@@ -330,7 +333,7 @@ def _optimize_batch_candidates(
         warnings.warn(first_warn_msg, RuntimeWarning)
 
         if not initial_conditions_provided:
-            batch_initial_conditions = opt_inputs.ic_generator(
+            batch_initial_conditions = opt_inputs.ic_gen(
                 acq_function=opt_inputs.acq_function,
                 bounds=opt_inputs.bounds,
                 q=opt_inputs.q,
@@ -340,7 +343,7 @@ def _optimize_batch_candidates(
                 options=options,
                 inequality_constraints=opt_inputs.inequality_constraints,
                 equality_constraints=opt_inputs.equality_constraints,
-                **kwargs,
+                **opt_inputs.ic_gen_kwargs,
             )
 
         batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
@@ -365,7 +368,7 @@ def _optimize_batch_candidates(
         batch_candidates = batch_candidates[best]
         batch_acq_values = batch_acq_values[best]
 
-    if full_tree:
+    if opt_inputs.full_tree:
         batch_candidates = opt_inputs.acq_function.extract_candidates(
             X_full=batch_candidates
         )
@@ -389,7 +392,7 @@ def optimize_acqf(
     return_best_only: bool = True,
     gen_candidates: Optional[TGenCandidates] = None,
     sequential: bool = False,
-    **kwargs: Any,
+    *,
+    ic_generator: Optional[TGenInitialConditions] = None,
+    timeout_sec: Optional[float] = None,
+    return_full_tree: bool = False,
+    **ic_gen_kwargs: Any,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of candidates via multi-start optimization.
@@ -435,7 +442,15 @@
             for method-specific inputs. Default: `gen_candidates_scipy`
         sequential: If False, uses joint optimization, otherwise uses sequential
             optimization.
-        kwargs: Additonal keyword arguments.
+        ic_generator: Function for generating initial conditions. Not needed when
+            `batch_initial_conditions` are provided. Defaults to
+            `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
+            functions and `gen_batch_initial_conditions` otherwise. Must be specified
+            for nonlinear inequality constraints.
+        timeout_sec: Max amount of time optimization can run for.
+        return_full_tree: For one-shot acquisition functions, return the full solution tree instead of the extracted candidates.
+        ic_gen_kwargs: Additional keyword arguments passed to the function
+            specified by `ic_generator`.
 
     Returns:
         A two-element tuple containing
@@ -481,7 +496,10 @@
         return_best_only=return_best_only,
         gen_candidates=gen_candidates,
         sequential=sequential,
-        kwargs=kwargs,
+        ic_generator=ic_generator,
+        timeout_sec=timeout_sec,
+        return_full_tree=return_full_tree,
+        ic_gen_kwargs=ic_gen_kwargs,
     )
     return _optimize_acqf(opt_acqf_inputs)

@@ -501,8 +519,7 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> Tuple[Tensor, Tensor]:
     )
 
     start_time: float = time.monotonic()
-    kwargs = opt_inputs.kwargs
-    timeout_sec = kwargs.pop("timeout_sec", None)
+    timeout_sec = opt_inputs.timeout_sec
 
     # Perform sequential optimization via successive conditioning on pending points
     if opt_inputs.sequential and opt_inputs.q > 1:
@@ -531,7 +548,11 @@ def optimize_acqf_cyclic(
     post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
     batch_initial_conditions: Optional[Tensor] = None,
     cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
-    **kwargs,
+    *,
+    ic_generator: Optional[TGenInitialConditions] = None,
+    timeout_sec: Optional[float] = None,
+    return_full_tree: bool = False,
+    **ic_gen_kwargs: Any,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of `q` candidates via cyclic optimization.
@@ -561,6 +582,15 @@
             If no initial conditions are provided, the default initialization will
             be used.
         cyclic_options: Options for stopping criterion for outer cyclic optimization.
+        ic_generator: Function for generating initial conditions. Not needed when
+            `batch_initial_conditions` are provided. Defaults to
+            `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
+            functions and `gen_batch_initial_conditions` otherwise. Must be specified
+            for nonlinear inequality constraints.
+        timeout_sec: Max amount of time optimization can run for.
+        return_full_tree: For one-shot acquisition functions, return the full solution tree instead of the extracted candidates.
+        ic_gen_kwargs: Additional keyword arguments passed to the function
+            specified by `ic_generator`.
 
     Returns:
         A two-element tuple containing
@@ -596,7 +626,10 @@
         return_best_only=True,
         gen_candidates=gen_candidates_scipy,
         sequential=True,
-        kwargs=kwargs,
+        ic_generator=ic_generator,
+        timeout_sec=timeout_sec,
+        return_full_tree=return_full_tree,
+        ic_gen_kwargs=ic_gen_kwargs,
     )
 
     # for the first cycle, optimize the q candidates sequentially
@@ -778,6 +811,8 @@ def optimize_acqf_mixed(
             transformations).
         batch_initial_conditions: A tensor to specify the initial conditions. Set
             this if you do not want to use default initialization strategy.
+        kwargs: kwargs do nothing. This is provided so that the same arguments can
+            be passed to different acquisition functions without raising an error.
 
     Returns:
         A two-element tuple containing
@@ -881,6 +916,8 @@ def optimize_acqf_discrete(
             a large training set.
         unique: If True return unique choices, o/w choices may be repeated
             (only relevant if `q > 1`).
+        kwargs: kwargs do nothing. This is provided so that the same arguments can
+            be passed to different acquisition functions without raising an error.
 
     Returns:
         A three-element tuple containing
@@ -1045,6 +1082,8 @@ def optimize_acqf_discrete_local_search(
             a large training set.
         unique: If True return unique choices, o/w choices may be repeated
             (only relevant if `q > 1`).
+        kwargs: kwargs do nothing. This is provided so that the same arguments can
+            be passed to different acquisition functions without raising an error.
 
     Returns:
         A two-element tuple containing
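
A design note on the `frozen=True` change above: `_optimize_acqf_sequential_q` now derives its per-iteration inputs with `dataclasses.replace` instead of mutating shared state, which removes exactly the `.pop()`-mutation hazard called out in the summary. A self-contained sketch of that pattern, using a stand-in class rather than `OptimizeAcqfInputs` itself:

# Sketch of the frozen-dataclass pattern used by _optimize_acqf_sequential_q.
# `Inputs` is a stand-in for OptimizeAcqfInputs; the values are illustrative.
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class Inputs:
    q: int
    sequential: bool
    timeout_sec: Optional[float] = None


base = Inputs(q=4, sequential=True, timeout_sec=40.0)
# Allocate the total timeout evenly across the q sequential sub-problems and
# derive an immutable copy; assigning `base.q = 1` would raise FrozenInstanceError.
per_iteration = dataclasses.replace(
    base, q=1, sequential=False, timeout_sec=base.timeout_sec / base.q
)
print(per_iteration)  # Inputs(q=1, sequential=False, timeout_sec=10.0)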
