Warn if inoperable keyword arguments are passed to optimizers (pytorch#1421)

Summary:
X-link: facebook/Ax#1421

Pull Request resolved: pytorch#1677

The previous diff discussed issues with kwargs in BoTorch optimizers and resolved some of them. This diff warns when keyword arguments that do nothing are passed to an optimizer, and stops Ax (botorch-modular) from passing them.
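As a minimal sketch of the new behavior (calling the private helper directly rather than setting up a full `optimize_acqf_discrete` run; not part of this diff), passing an inapplicable keyword argument now emits a `DeprecationWarning` instead of being silently ignored:

    import warnings

    from botorch.optim.optimize import _raise_deprecation_warning_if_kwargs

    # An optimizer that receives kwargs it cannot use forwards the leftovers here.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _raise_deprecation_warning_if_kwargs("optimize_acqf_discrete", {"num_restarts": 8})

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    print(caught[0].message)
    # `optimize_acqf_discrete` does not support arguments ['num_restarts'].
    # In the future, this will become an error.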

Reviewed By: lena-kashtelyan

Differential Revision: D43277102

fbshipit-source-id: f22b4cd1c4b6227a5210f27d1501c3667209a1d7
esantorella authored and facebook-github-bot committed Feb 14, 2023
1 parent b1b20e4 commit c728104
Showing 2 changed files with 32 additions and 1 deletion.
21 changes: 21 additions & 0 deletions botorch/optim/optimize.py
@@ -156,6 +156,24 @@ def ic_gen(self) -> TGenInitialConditions:
        return gen_batch_initial_conditions


def _raise_deprecation_warning_if_kwargs(fn_name: str, kwargs: Dict[str, Any]) -> None:
    """
    Raise a warning if kwargs are provided.

    Some functions used to support **kwargs. The applicable parameters have now been
    refactored to be named arguments, so no warning will be raised for users passing
    the expected arguments. However, if a user had been passing an inapplicable
    keyword argument, this will now raise a warning whereas in the past it did
    nothing.
    """
    if len(kwargs) > 0:
        warnings.warn(
            f"`{fn_name}` does not support arguments {list(kwargs.keys())}. In "
            "the future, this will become an error.",
            DeprecationWarning,
        )


def _optimize_acqf_all_features_fixed(
    *,
    bounds: Tensor,
@@ -830,6 +848,7 @@ def optimize_acqf_mixed(
            "are currently not supported when `q > 1`. This is needed to "
            "compute the joint acquisition value."
        )
    _raise_deprecation_warning_if_kwargs("optimize_acqf_mixed", kwargs)

    if q == 1:
        ff_candidate_list, ff_acq_value_list = [], []
@@ -932,6 +951,7 @@ def optimize_acqf_discrete(
        )
    if choices.numel() == 0:
        raise InputDataError("`choices` must be non-empty.")
    _raise_deprecation_warning_if_kwargs("optimize_acqf_discrete", kwargs)
    choices_batched = choices.unsqueeze(-2)
    if q > 1:
        candidate_list, acq_value_list = [], []
@@ -1091,6 +1111,7 @@ def optimize_acqf_discrete_local_search(
        - a `q x d`-dim tensor of generated candidates.
        - an associated acquisition value.
    """
    _raise_deprecation_warning_if_kwargs("optimize_acqf_discrete_local_search", kwargs)
    candidate_list = []
    base_X_pending = acq_function.X_pending if q > 1 else None
    base_X_avoid = X_avoid
12 changes: 11 additions & 1 deletion test/optim/test_optimize.py
@@ -1394,7 +1394,6 @@ def test_optimize_acqf_discrete(self):

        mock_acq_function = SquaredAcquisitionFunction()
        mock_acq_function.set_X_pending(None)
        # ensure proper raising of errors if no choices
        with self.assertRaisesRegex(InputDataError, "`choices` must be non-empty."):
            optimize_acqf_discrete(
@@ -1404,6 +1403,17 @@
            )

        choices = torch.rand(5, 2, **tkwargs)

        # warning for unsupported keyword arguments
        with self.assertWarnsRegex(
            DeprecationWarning,
            r"`optimize_acqf_discrete` does not support arguments "
            r"\['num_restarts'\]. In the future, this will become an error.",
        ):
            optimize_acqf_discrete(
                acq_function=mock_acq_function, q=q, choices=choices, num_restarts=8
            )

        exp_acq_vals = mock_acq_function(choices)

        # test unique
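Since the warning message states that this will eventually become an error, downstream code can opt into the stricter behavior today. A small sketch using only the standard `warnings` filter machinery (the exact future error type is not specified in this diff):

    import warnings

    # Escalate the new "does not support arguments" deprecation warnings to errors,
    # so inoperable kwargs fail loudly instead of merely warning.
    warnings.filterwarnings(
        "error",
        message=r"`optimize_acqf_discrete` does not support arguments",
        category=DeprecationWarning,
    )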
