Skip to content

Commit

Permalink
Suppress non-log EI warnings in HVKG helpers
Browse files Browse the repository at this point in the history
Summary: My first idea was to change this to use Log-EHVI, but this helper is used in the tutorial to compute the hypervolume, so we can't change the return value. Since this is not directly constructed by the user, raising a warning is not productive. This diff suppresses the warning.

Differential Revision: D61722790
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Aug 23, 2024
1 parent d03b4ed commit 2592514
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Learning, 2023.
"""

import warnings
from copy import deepcopy
from typing import Any, Callable, Optional

Expand All @@ -35,6 +36,7 @@
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
from botorch.models.deterministic import PosteriorMeanModel
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
Expand Down Expand Up @@ -500,20 +502,22 @@ def _get_hv_value_function(
if use_posterior_mean:
model = PosteriorMeanModel(model=model)
sampler = StochasticSampler(sample_shape=torch.Size([1])) # dummy sampler
base_value_function = qExpectedHypervolumeImprovement(
model=model,
ref_point=ref_point,
partitioning=FastNondominatedPartitioning(
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=NumericsWarning)
base_value_function = qExpectedHypervolumeImprovement(
model=model,
ref_point=ref_point,
Y=torch.empty(
(0, ref_point.shape[0]),
dtype=ref_point.dtype,
device=ref_point.device,
),
), # create empty partitioning
sampler=sampler,
objective=objective,
)
partitioning=FastNondominatedPartitioning(
ref_point=ref_point,
Y=torch.empty(
(0, ref_point.shape[0]),
dtype=ref_point.dtype,
device=ref_point.device,
),
), # create empty partitioning
sampler=sampler,
objective=objective,
)
# ProjectedAcquisitionFunction requires this
base_value_function.posterior_transform = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from itertools import product
from unittest import mock

Expand All @@ -22,6 +23,7 @@
IdentityMCMultiOutputObjective,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
Expand Down Expand Up @@ -366,24 +368,26 @@ def test_evaluate_q_hvkg(self):
for use_posterior_mean in (True, False):
with mock.patch.object(
ModelListGP, "fantasize", return_value=mfm
) as patch_f:
with mock.patch(
NO, new_callable=mock.PropertyMock
) as mock_num_outputs:
mock_num_outputs.return_value = 3
qHVKG = acqf_class(
model=model,
num_fantasies=n_f,
objective=objective,
ref_point=ref_point,
num_pareto=num_pareto,
use_posterior_mean=use_posterior_mean,
**mf_kwargs,
)
val = qHVKG(X)
patch_f.assert_called_once()
cargs, ckwargs = patch_f.call_args
self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1]))
) as patch_f, mock.patch(
NO, new_callable=mock.PropertyMock
) as mock_num_outputs, warnings.catch_warnings(
record=True
) as ws:
mock_num_outputs.return_value = 3
qHVKG = acqf_class(
model=model,
num_fantasies=n_f,
objective=objective,
ref_point=ref_point,
num_pareto=num_pareto,
use_posterior_mean=use_posterior_mean,
**mf_kwargs,
)
val = qHVKG(X)
patch_f.assert_called_once()
cargs, ckwargs = patch_f.call_args
self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1]))
self.assertFalse(any(w.category is NumericsWarning for w in ws))
Ys = mean if use_posterior_mean else samples
objs = objective(Ys.squeeze(1)).view(-1, num_pareto, num_objectives)
if num_objectives == 2:
Expand Down

0 comments on commit 2592514

Please sign in to comment.