Do not filter out BoTorchWarnings by default #1630

Closed · wants to merge 1 commit
3 changes: 0 additions & 3 deletions botorch/optim/utils/timeout.py
@@ -7,12 +7,10 @@
 from __future__ import annotations
 
 import time
-import warnings
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
 from botorch.exceptions.errors import OptimizationTimeoutError
-from botorch.exceptions.warnings import OptimizationWarning
 from scipy import optimize

@@ -95,7 +93,6 @@ def wrapped_callback(xk: np.ndarray) -> None:
         )
     except OptimizationTimeoutError as e:
         msg = f"Optimization timed out after {e.runtime} seconds."
-        warnings.warn(msg, OptimizationWarning)
         current_fun, *_ = fun(e.current_x, *args)
 
         return optimize.OptimizeResult(
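After this change, a timed-out optimization no longer surfaces through the warning system; the partial result is simply returned. A minimal caller-side sketch — assuming (this is not part of the diff) that minimize_with_timeout mirrors scipy.optimize.minimize's signature (method, jac, ...), accepts a timeout_sec keyword, and records the timeout text on the returned OptimizeResult's message, since the msg variable is still computed above:

import numpy as np
from botorch.optim.utils.timeout import minimize_with_timeout

def f_with_grad(x: np.ndarray):
    # Quadratic objective returning (value, gradient), as expected with jac=True.
    return float((x ** 2).sum()), 2 * x

res = minimize_with_timeout(
    f_with_grad,
    np.array([1.0, 2.0]),
    method="L-BFGS-B",
    jac=True,
    timeout_sec=1.0,  # assumed keyword; the wrapper handles OptimizationTimeoutError internally
)
if "timed out" in str(res.message).lower():
    # The timeout is now visible only here, not as an OptimizationWarning.
    print(f"stopped early: x={res.x}, fun={res.fun}")
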
23 changes: 4 additions & 19 deletions botorch/settings.py
@@ -10,10 +10,6 @@
 
 from __future__ import annotations
 
-import typing  # noqa F401
-import warnings
-
-from botorch.exceptions import BotorchWarning
 from botorch.logging import LOG_LEVEL_DEFAULT, logger


@@ -55,30 +51,19 @@ class propagate_grads(_Flag):
     _state: bool = False
 
 
-def suppress_botorch_warnings(suppress: bool) -> None:
-    r"""Set botorch warning filter.
-
-    Args:
-        state: A boolean indicating whether warnings should be prints
-    """
-    warnings.simplefilter("ignore" if suppress else "default", BotorchWarning)
-
-
 class debug(_Flag):
-    r"""Flag for printing verbose BotorchWarnings.
+    r"""Flag for printing verbose warnings.
 
-    When set to `True`, verbose `BotorchWarning`s will be printed for debuggability.
-    Warnings that are not subclasses of `BotorchWarning` will not be affected by
-    this context_manager.
+    To make sure a warning is only raised in debug mode:
+    >>> if debug.on():
+    >>>     warnings.warn(<some warning>)
     """
 
     _state: bool = False
-    suppress_botorch_warnings(suppress=not _state)
 
     @classmethod
     def _set_state(cls, state: bool) -> None:
         cls._state = state
-        suppress_botorch_warnings(suppress=not cls._state)
 
 
 class validate_input_scaling(_Flag):
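Net effect of this file's change: importing botorch no longer installs an "ignore" filter for BotorchWarning, and the debug flag no longer touches warning filters at all. A short sketch of both sides of the new contract — the library-side pattern comes straight from the new docstring, and the user-side line just reuses the standard warnings machinery that the deleted helper wrapped:

import warnings

from botorch import settings
from botorch.exceptions import BotorchWarning

# Library-side: gate verbose diagnostics on the debug flag.
if settings.debug.on():
    warnings.warn("verbose diagnostic detail", BotorchWarning)

# User-side: anyone who preferred the old default silence can restore it.
warnings.simplefilter("ignore", BotorchWarning)
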
11 changes: 11 additions & 0 deletions botorch/utils/testing.py
@@ -15,6 +15,7 @@
 import torch
 from botorch import settings
 from botorch.acquisition.objective import PosteriorTransform
+from botorch.exceptions.warnings import BotorchTensorDimensionWarning, InputDataWarning
 from botorch.models.model import FantasizeMixin, Model
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
@@ -50,6 +51,16 @@ def setUp(self):
             message="The model inputs are of type",
             category=UserWarning,
         )
+        warnings.filterwarnings(
+            "ignore",
+            message="Non-strict enforcement of botorch tensor conventions.",
+            category=BotorchTensorDimensionWarning,
+        )
+        warnings.filterwarnings(
+            "ignore",
+            message="Input data is not standardized.",
+            category=InputDataWarning,
+        )
 
     def assertAllClose(
         self,
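For context on how these filters behave (standard library semantics, not something this PR adds): the message argument is a regular expression matched against the beginning of the warning text, so only the named warnings are silenced in tests while everything else still surfaces. A self-contained demo using plain UserWarning as a stand-in for the botorch categories:

import warnings

warnings.filterwarnings(
    "ignore",
    message="Input data is not standardized.",
    category=UserWarning,
)

# Suppressed: the pattern matches the start of the warning text.
warnings.warn("Input data is not standardized. Please consider scaling.", UserWarning)

# Still shown: same category, but the message does not match.
warnings.warn("Some unrelated warning.", UserWarning)
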
45 changes: 20 additions & 25 deletions test/models/utils/test_assorted.py
@@ -134,34 +134,29 @@ def test_check_min_max_scaling(self):
         )
 
     def test_check_standardization(self):
+        # Ensure that it is not filtered out.
+        warnings.filterwarnings("always", category=InputDataWarning)
         Y = torch.randn(3, 4, 2)
         # check standardized input
         Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
-        with settings.debug(True):
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst)
-                self.assertFalse(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-            check_standardization(Y=Yst, raise_on_fail=True)
-            # check nonzero mean
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst + 1)
-                self.assertTrue(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-                self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-            with self.assertRaises(InputDataError):
-                check_standardization(Y=Yst + 1, raise_on_fail=True)
-            # check non-unit variance
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst * 2)
-                self.assertTrue(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-                self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-            with self.assertRaises(InputDataError):
-                check_standardization(Y=Yst * 2, raise_on_fail=True)
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst)
+        self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
+        check_standardization(Y=Yst, raise_on_fail=True)
+        # check nonzero mean
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst + 1)
+        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
+        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
+        with self.assertRaises(InputDataError):
+            check_standardization(Y=Yst + 1, raise_on_fail=True)
+        # check non-unit variance
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst * 2)
+        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
+        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
+        with self.assertRaises(InputDataError):
+            check_standardization(Y=Yst * 2, raise_on_fail=True)
 
     def test_validate_input_scaling(self):
         train_X = 2 + torch.rand(3, 4, 3)
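The "always" filter added at the top of test_check_standardization matters because warnings.catch_warnings(record=True) inherits the ambient filters: with the "ignore" filter that BotorchTestCase.setUp now installs (or the default once-per-location rule), the recorder could capture nothing and the assertions would be vacuous. A minimal illustration of the pitfall, independent of botorch:

import warnings

warnings.simplefilter("ignore", UserWarning)  # e.g. installed by a test harness

with warnings.catch_warnings(record=True) as ws:
    warnings.warn("not standardized", UserWarning)
print(len(ws))  # 0 -- the inherited "ignore" filter swallowed the warning

warnings.filterwarnings("always", category=UserWarning)
with warnings.catch_warnings(record=True) as ws:
    warnings.warn("not standardized", UserWarning)
print(len(ws))  # 1 -- "always" guarantees the recorder sees it
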
24 changes: 14 additions & 10 deletions test/test_settings.py
@@ -25,28 +25,32 @@ def test_flags(self):
             self.assertTrue(flag.off())
 
     def test_debug(self):
-        # turn on BotorchWarning
+        # Turn on debug.
         settings.debug._set_state(True)
-        # check that warnings are suppressed
+        # Check that debug warnings are suppressed when it is turned off.
         with settings.debug(False):
             with warnings.catch_warnings(record=True) as ws:
-                warnings.warn("test", BotorchWarning)
+                if settings.debug.on():
+                    warnings.warn("test", BotorchWarning)
             self.assertEqual(len(ws), 0)
-        # check that warnings are not suppressed outside of context manager
+        # Check that warnings are not suppressed outside of context manager.
         with warnings.catch_warnings(record=True) as ws:
-            warnings.warn("test", BotorchWarning)
+            if settings.debug.on():
+                warnings.warn("test", BotorchWarning)
         self.assertEqual(len(ws), 1)
 
-        # turn off BotorchWarnings
+        # Turn off debug.
         settings.debug._set_state(False)
-        # check that warnings are not suppressed
+        # Check that warnings are not suppressed within debug.
         with settings.debug(True):
             with warnings.catch_warnings(record=True) as ws:
-                warnings.warn("test", BotorchWarning)
+                if settings.debug.on():
+                    warnings.warn("test", BotorchWarning)
             self.assertEqual(len(ws), 1)
-        # check that warnings are suppressed outside of context manager
+        # Check that warnings are suppressed outside of context manager.
         with warnings.catch_warnings(record=True) as ws:
-            warnings.warn("test", BotorchWarning)
+            if settings.debug.on():
+                warnings.warn("test", BotorchWarning)
         self.assertEqual(len(ws), 0)
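These assertions exercise the two ways a botorch _Flag can be driven, both visible in the diff: _set_state flips the process-wide default, while using the flag as a context manager overrides it only for the enclosed block. A condensed sketch of those semantics:

from botorch import settings

settings.debug._set_state(False)  # global default: off
assert settings.debug.off()

with settings.debug(True):        # temporary override
    assert settings.debug.on()

assert settings.debug.off()       # restored on exit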

