diff --git a/botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py b/botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py index 8b017654b1..edd37acf5a 100644 --- a/botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py +++ b/botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py @@ -16,6 +16,7 @@ Learning, 2023. """ +import warnings from copy import deepcopy from typing import Any, Callable, Optional @@ -35,6 +36,7 @@ ) from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective from botorch.exceptions.errors import UnsupportedError +from botorch.exceptions.warnings import NumericsWarning from botorch.models.deterministic import PosteriorMeanModel from botorch.models.model import Model from botorch.sampling.base import MCSampler @@ -500,20 +502,22 @@ def _get_hv_value_function( if use_posterior_mean: model = PosteriorMeanModel(model=model) sampler = StochasticSampler(sample_shape=torch.Size([1])) # dummy sampler - base_value_function = qExpectedHypervolumeImprovement( - model=model, - ref_point=ref_point, - partitioning=FastNondominatedPartitioning( + with warnings.catch_warnings(): + warnings.simplefilter(action="ignore", category=NumericsWarning) + base_value_function = qExpectedHypervolumeImprovement( + model=model, ref_point=ref_point, - Y=torch.empty( - (0, ref_point.shape[0]), - dtype=ref_point.dtype, - device=ref_point.device, - ), - ), # create empty partitioning - sampler=sampler, - objective=objective, - ) + partitioning=FastNondominatedPartitioning( + ref_point=ref_point, + Y=torch.empty( + (0, ref_point.shape[0]), + dtype=ref_point.dtype, + device=ref_point.device, + ), + ), # create empty partitioning + sampler=sampler, + objective=objective, + ) # ProjectedAcquisitionFunction requires this base_value_function.posterior_transform = None diff --git a/test/acquisition/multi_objective/test_hypervolume_knowledge_gradient.py b/test/acquisition/multi_objective/test_hypervolume_knowledge_gradient.py index a6d22a8550..04987b860e 100644 --- a/test/acquisition/multi_objective/test_hypervolume_knowledge_gradient.py +++ b/test/acquisition/multi_objective/test_hypervolume_knowledge_gradient.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import warnings from itertools import product from unittest import mock @@ -22,6 +23,7 @@ IdentityMCMultiOutputObjective, ) from botorch.exceptions.errors import UnsupportedError +from botorch.exceptions.warnings import NumericsWarning from botorch.models.deterministic import GenericDeterministicModel from botorch.models.gp_regression import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP @@ -366,24 +368,26 @@ def test_evaluate_q_hvkg(self): for use_posterior_mean in (True, False): with mock.patch.object( ModelListGP, "fantasize", return_value=mfm - ) as patch_f: - with mock.patch( - NO, new_callable=mock.PropertyMock - ) as mock_num_outputs: - mock_num_outputs.return_value = 3 - qHVKG = acqf_class( - model=model, - num_fantasies=n_f, - objective=objective, - ref_point=ref_point, - num_pareto=num_pareto, - use_posterior_mean=use_posterior_mean, - **mf_kwargs, - ) - val = qHVKG(X) - patch_f.assert_called_once() - cargs, ckwargs = patch_f.call_args - self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1])) + ) as patch_f, mock.patch( + NO, new_callable=mock.PropertyMock + ) as mock_num_outputs, warnings.catch_warnings( + record=True + ) as ws: + mock_num_outputs.return_value = 3 + qHVKG = acqf_class( + model=model, + num_fantasies=n_f, + objective=objective, + ref_point=ref_point, + num_pareto=num_pareto, + use_posterior_mean=use_posterior_mean, + **mf_kwargs, + ) + val = qHVKG(X) + patch_f.assert_called_once() + cargs, ckwargs = patch_f.call_args + self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1])) + self.assertFalse(any(w.category is NumericsWarning for w in ws)) Ys = mean if use_posterior_mean else samples objs = objective(Ys.squeeze(1)).view(-1, num_pareto, num_objectives) if num_objectives == 2: