Warn if inoperable keyword arguments are passed to optimizers #1677

Closed
wants to merge 2 commits into from
18 changes: 17 additions & 1 deletion botorch/optim/initializers.py
@@ -16,7 +16,7 @@

import warnings
from math import ceil
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from botorch import settings
@@ -46,6 +46,22 @@
from torch.distributions import Normal
from torch.quasirandom import SobolEngine

TGenInitialConditions = Callable[
[
# reasoning behind this annotation: contravariance
qKnowledgeGradient,
Tensor,
int,
int,
int,
Optional[Dict[int, float]],
Optional[Dict[str, Union[bool, float, int]]],
Optional[List[Tuple[Tensor, Tensor, float]]],
Optional[List[Tuple[Tensor, Tensor, float]]],
],
Optional[Tensor],
]


def gen_batch_initial_conditions(
acq_function: AcquisitionFunction,
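As background for the new `TGenInitialConditions` alias: it is a purely positional `Callable` type, and its first parameter is annotated with the most specific acquisition class it must handle (`qKnowledgeGradient`), so by contravariance a generator accepting any `AcquisitionFunction` also conforms. A minimal sketch of a conforming generator, assuming the keyword names used at the `ic_gen` call sites in `optimize.py` below (the function name and uniform sampling are illustrative, not part of this PR):

from typing import Dict, List, Optional, Tuple, Union

import torch
from botorch.acquisition import AcquisitionFunction
from torch import Tensor


def uniform_initial_conditions(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int]]] = None,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
) -> Optional[Tensor]:
    # Draw a `num_restarts x q x d` batch of starting points uniformly within
    # `bounds`; constraints and `raw_samples` are ignored here for brevity.
    lower, upper = bounds[0], bounds[1]
    rnd = torch.rand(
        num_restarts, q, bounds.shape[-1], dtype=bounds.dtype, device=bounds.device
    )
    return lower + (upper - lower) * rnd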
146 changes: 103 additions & 43 deletions botorch/optim/optimize.py
@@ -30,6 +30,7 @@
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_kg_initial_conditions,
TGenInitialConditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import _filter_kwargs
@@ -53,7 +54,7 @@
}


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class OptimizeAcqfInputs:
"""
Container for inputs to `optimize_acqf`.
Expand All @@ -76,10 +77,19 @@ class OptimizeAcqfInputs:
return_best_only: bool
gen_candidates: TGenCandidates
sequential: bool
kwargs: Dict[str, Any]
ic_generator: Callable = dataclasses.field(init=False)
ic_generator: Optional[TGenInitialConditions] = None
timeout_sec: Optional[float] = None
return_full_tree: bool = False
ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)

@property
def full_tree(self) -> bool:
return (
isinstance(self.acq_function, OneShotAcquisitionFunction)
and not self.return_full_tree
)

def _validate(self) -> None:
def __post_init__(self) -> None:
if self.inequality_constraints is None and not (
self.bounds.ndim == 2 and self.bounds.shape[0] == 2
):
@@ -114,7 +124,7 @@ def _validate(self) -> None:
f"shape is {batch_initial_conditions_shape}."
)

elif "ic_generator" not in self.kwargs.keys():
elif not self.ic_generator:
if self.nonlinear_inequality_constraints:
raise RuntimeError(
"`ic_generator` must be given if "
@@ -137,14 +147,31 @@ def _validate(self) -> None:
"acquisition functions. Must have `sequential=False`."
)

def __post_init__(self) -> None:
self._validate()
if "ic_generator" in self.kwargs.keys():
self.ic_generator = self.kwargs.pop("ic_generator")
@property
def ic_gen(self) -> TGenInitialConditions:
if self.ic_generator:
return self.ic_generator
elif isinstance(self.acq_function, qKnowledgeGradient):
self.ic_generator = gen_one_shot_kg_initial_conditions
else:
self.ic_generator = gen_batch_initial_conditions
return gen_one_shot_kg_initial_conditions
return gen_batch_initial_conditions


def _raise_deprecation_warning_if_kwargs(fn_name: str, kwargs: Dict[str, Any]) -> None:
"""
Raise a warning if kwargs are provided.

Some functions used to support **kwargs. The applicable parameters have now been
refactored to be named arguments, so no warning will be raised for users passing
the expected arguments. However, if a user had been passing an inapplicable
keyword argument, this will now raise a warning whereas in the past it did
nothing.
"""
if len(kwargs) > 0:
warnings.warn(
f"`{fn_name}` does not support arguments {list(kwargs.keys())}. In "
"the future, this will become an error.",
DeprecationWarning,
)
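For illustration, a hypothetical caller (`my_optimizer` is not part of this PR) showing the intended behavior once the named parameters have been consumed:

from typing import Any


def my_optimizer(q: int, num_restarts: int = 20, **kwargs: Any) -> None:
    # Named parameters bind normally; anything left over in `kwargs` is
    # inoperable and triggers the deprecation warning above.
    _raise_deprecation_warning_if_kwargs("my_optimizer", kwargs)


my_optimizer(q=1, frac_random=0.5)
# DeprecationWarning: `my_optimizer` does not support arguments
# ['frac_random']. In the future, this will become an error.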


def _optimize_acqf_all_features_fixed(
Expand All @@ -170,34 +197,32 @@ def _optimize_acqf_all_features_fixed(


def _optimize_acqf_sequential_q(
opt_inputs: OptimizeAcqfInputs,
timeout_sec: Optional[float],
start_time: float,
opt_inputs: OptimizeAcqfInputs, timeout_sec: Optional[float], start_time: float
) -> Tuple[Tensor, Tensor]:
"""
Helper function for `optimize_acqf` when sequential=True and q > 1.
"""
kwargs = opt_inputs.kwargs or {}
if timeout_sec is not None:
# When using sequential optimization, we allocate the total timeout
# evenly across the individual acquisition optimizations.
timeout_sec = (timeout_sec - start_time) / opt_inputs.q
kwargs["timeout_sec"] = timeout_sec

candidate_list, acq_value_list = [], []
base_X_pending = opt_inputs.acq_function.X_pending

new_inputs = dataclasses.replace(
opt_inputs,
q=1,
batch_initial_conditions=None,
return_best_only=True,
sequential=False,
timeout_sec=timeout_sec,
)
for i in range(opt_inputs.q):
kwargs["ic_generator"] = opt_inputs.ic_generator
new_inputs = dataclasses.replace(
opt_inputs,
q=1,
batch_initial_conditions=None,
return_best_only=True,
sequential=False,
kwargs=kwargs,
)

candidate, acq_value = _optimize_acqf_batch(
new_inputs, start_time=start_time, timeout_sec=timeout_sec
)
candidate, acq_value = _optimize_acqf(new_inputs)

candidate_list.append(candidate)
acq_value_list.append(acq_value)
@@ -217,17 +242,13 @@ def _optimize_acqf_batch(
) -> Tuple[Tensor, Tensor]:
options = opt_inputs.options or {}

kwargs = opt_inputs.kwargs
full_tree = isinstance(
opt_inputs.acq_function, OneShotAcquisitionFunction
) and not kwargs.pop("return_full_tree", False)

initial_conditions_provided = opt_inputs.batch_initial_conditions is not None

if initial_conditions_provided:
batch_initial_conditions = opt_inputs.batch_initial_conditions
else:
batch_initial_conditions = opt_inputs.ic_generator(
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
batch_initial_conditions = opt_inputs.ic_gen(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
Expand All @@ -237,7 +258,7 @@ def _optimize_acqf_batch(
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**kwargs,
**opt_inputs.ic_gen_kwargs,
)

batch_limit: int = options.get(
@@ -330,7 +351,7 @@ def _optimize_batch_candidates(
warnings.warn(first_warn_msg, RuntimeWarning)

if not initial_conditions_provided:
batch_initial_conditions = opt_inputs.ic_generator(
batch_initial_conditions = opt_inputs.ic_gen(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
Expand All @@ -340,7 +361,7 @@ def _optimize_batch_candidates(
options=options,
inequality_constraints=opt_inputs.inequality_constraints,
equality_constraints=opt_inputs.equality_constraints,
**kwargs,
**opt_inputs.ic_gen_kwargs,
)

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(
@@ -365,7 +386,7 @@ def _optimize_batch_candidates(
batch_candidates = batch_candidates[best]
batch_acq_values = batch_acq_values[best]

if full_tree:
if opt_inputs.full_tree:
batch_candidates = opt_inputs.acq_function.extract_candidates(
X_full=batch_candidates
)
@@ -389,7 +410,11 @@ def optimize_acqf(
return_best_only: bool = True,
gen_candidates: Optional[TGenCandidates] = None,
sequential: bool = False,
**kwargs: Any,
*,
ic_generator: Optional[TGenInitialConditions] = None,
timeout_sec: Optional[float] = None,
return_full_tree: bool = False,
**ic_gen_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of candidates via multi-start optimization.

@@ -435,7 +460,15 @@
for method-specific inputs. Default: `gen_candidates_scipy`
sequential: If False, uses joint optimization, otherwise uses sequential
optimization.
kwargs: Additonal keyword arguments.
ic_generator: Function for generating initial conditions. Not needed when
`batch_initial_conditions` are provided. Defaults to
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree: For one-shot acquisition functions, whether to return the
full tree of solutions rather than only the extracted final candidates.
ic_gen_kwargs: Additional keyword arguments passed to the function specified
by `ic_generator`.

Returns:
A two-element tuple containing
@@ -481,7 +514,10 @@
return_best_only=return_best_only,
gen_candidates=gen_candidates,
sequential=sequential,
kwargs=kwargs,
ic_generator=ic_generator,
timeout_sec=timeout_sec,
return_full_tree=return_full_tree,
ic_gen_kwargs=ic_gen_kwargs,
)
return _optimize_acqf(opt_acqf_inputs)
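A usage sketch of the resulting signature; `acqf` and `bounds` are assumed to be an acquisition function and a `2 x d` bounds tensor, and the numeric values are illustrative:

candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=10,
    raw_samples=256,
    timeout_sec=60.0,  # now an explicit keyword-only argument
    # ic_generator may be omitted; it defaults per acquisition type (see ic_gen)
)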

@@ -501,8 +537,7 @@ def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> Tuple[Tensor, Tensor]:
)

start_time: float = time.monotonic()
kwargs = opt_inputs.kwargs
timeout_sec = kwargs.pop("timeout_sec", None)
timeout_sec = opt_inputs.timeout_sec

# Perform sequential optimization via successive conditioning on pending points
if opt_inputs.sequential and opt_inputs.q > 1:
@@ -531,7 +566,11 @@ def optimize_acqf_cyclic(
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
batch_initial_conditions: Optional[Tensor] = None,
cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
**kwargs,
*,
ic_generator: Optional[TGenInitialConditions] = None,
timeout_sec: Optional[float] = None,
return_full_tree: bool = False,
**ic_gen_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of `q` candidates via cyclic optimization.

@@ -561,6 +600,15 @@
If no initial conditions are provided, the default initialization will
be used.
cyclic_options: Options for stopping criterion for outer cyclic optimization.
ic_generator: Function for generating initial conditions. Not needed when
`batch_initial_conditions` are provided. Defaults to
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree: For one-shot acquisition functions, whether to return the
full tree of solutions rather than only the extracted final candidates.
ic_gen_kwargs: Additional keyword arguments passed to the function specified
by `ic_generator`.

Returns:
A two-element tuple containing
@@ -596,7 +644,10 @@
return_best_only=True,
gen_candidates=gen_candidates_scipy,
sequential=True,
kwargs=kwargs,
ic_generator=ic_generator,
timeout_sec=timeout_sec,
return_full_tree=return_full_tree,
ic_gen_kwargs=ic_gen_kwargs,
)

# for the first cycle, optimize the q candidates sequentially
@@ -778,6 +829,8 @@ def optimize_acqf_mixed(
transformations).
batch_initial_conditions: A tensor to specify the initial conditions. Set
this if you do not want to use default initialization strategy.
kwargs: kwargs do nothing. This is provided so that the same arguments can
be passed to different optimizers without raising an error.

Returns:
A two-element tuple containing
@@ -795,6 +848,7 @@
"are currently not supported when `q > 1`. This is needed to "
"compute the joint acquisition value."
)
_raise_deprecation_warning_if_kwargs("optimize_acqf_mixed", kwargs)

if q == 1:
ff_candidate_list, ff_acq_value_list = [], []
@@ -881,6 +935,8 @@ def optimize_acqf_discrete(
a large training set.
unique: If True return unique choices, o/w choices may be repeated
(only relevant if `q > 1`).
kwargs: kwargs do nothing. This is provided so that the same arguments can
be passed to different optimizers without raising an error.

Returns:
A three-element tuple containing
@@ -895,6 +951,7 @@
)
if choices.numel() == 0:
raise InputDataError("`choices` must be non-empty.")
_raise_deprecation_warning_if_kwargs("optimize_acqf_discrete", kwargs)
choices_batched = choices.unsqueeze(-2)
if q > 1:
candidate_list, acq_value_list = [], []
@@ -1045,13 +1102,16 @@ def optimize_acqf_discrete_local_search(
a large training set.
unique: If True return unique choices, o/w choices may be repeated
(only relevant if `q > 1`).
kwargs: kwargs do nothing. This is provided so that the same arguments can
be passed to different optimizers without raising an error.

Returns:
A two-element tuple containing

- a `q x d`-dim tensor of generated candidates.
- an associated acquisition value.
"""
_raise_deprecation_warning_if_kwargs("optimize_acqf_discrete_local_search", kwargs)
candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
base_X_avoid = X_avoid
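Because `OptimizeAcqfInputs` is now frozen, per-iteration variants are derived with `dataclasses.replace` rather than in-place mutation (as `_optimize_acqf_sequential_q` does above). A self-contained sketch of that pattern, with illustrative names:

import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class Inputs:
    q: int
    sequential: bool
    timeout_sec: Optional[float] = None


base = Inputs(q=4, sequential=True, timeout_sec=120.0)
# Assigning base.q = 1 would raise dataclasses.FrozenInstanceError;
# derive a modified copy instead, as the sequential path does for each step.
per_step = dataclasses.replace(base, q=1, sequential=False, timeout_sec=30.0)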
12 changes: 11 additions & 1 deletion test/optim/test_optimize.py
@@ -1394,7 +1394,6 @@ def test_optimize_acqf_discrete(self):

mock_acq_function = SquaredAcquisitionFunction()
mock_acq_function.set_X_pending(None)

# ensure proper raising of errors if no choices
with self.assertRaisesRegex(InputDataError, "`choices` must be non-empty."):
optimize_acqf_discrete(
@@ -1404,6 +1403,17 @@
)

choices = torch.rand(5, 2, **tkwargs)

# warning for unsupported keyword arguments
with self.assertWarnsRegex(
DeprecationWarning,
r"`optimize_acqf_discrete` does not support arguments "
r"\['num_restarts'\]. In the future, this will become an error.",
):
optimize_acqf_discrete(
acq_function=mock_acq_function, q=q, choices=choices, num_restarts=8
)

exp_acq_vals = mock_acq_function(choices)

# test unique