Skip to content

Commit

Permalink
Do not filter out BoTorchWarnings by default (pytorch#1630)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1630

Updates `debug` flag to be an opt-in only. Previously this was filtering out all BoTorchWarnings, which significantly limited usage of these warnings.

Differential Revision: D42017622

fbshipit-source-id: 9d46140150b896a40eb058310677aacd7005e6da
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 14, 2023
1 parent 3596b12 commit 413dfa0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
3 changes: 0 additions & 3 deletions botorch/optim/utils/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
from __future__ import annotations

import time
import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from botorch.exceptions.errors import OptimizationTimeoutError
from botorch.exceptions.warnings import OptimizationWarning
from scipy import optimize


Expand Down Expand Up @@ -95,7 +93,6 @@ def wrapped_callback(xk: np.ndarray) -> None:
)
except OptimizationTimeoutError as e:
msg = f"Optimization timed out after {e.runtime} seconds."
warnings.warn(msg, OptimizationWarning)
current_fun, *_ = fun(e.current_x, *args)

return optimize.OptimizeResult(
Expand Down
12 changes: 6 additions & 6 deletions botorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,20 @@ def suppress_botorch_warnings(suppress: bool) -> None:


class debug(_Flag):
r"""Flag for printing verbose BotorchWarnings.
r"""Flag for printing verbose warnings.
When set to `True`, verbose `BotorchWarning`s will be printed for debuggability.
Warnings that are not subclasses of `BotorchWarning` will not be affected by
this context_manager.
To make sure a warning is only raised in debug mode:
```
if debug.on():
warnings.warn(<some warning>)
```
"""

_state: bool = False
suppress_botorch_warnings(suppress=not _state)

@classmethod
def _set_state(cls, state: bool) -> None:
cls._state = state
suppress_botorch_warnings(suppress=not cls._state)


class validate_input_scaling(_Flag):
Expand Down
11 changes: 11 additions & 0 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from botorch import settings
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.warnings import BotorchTensorDimensionWarning, InputDataWarning
from botorch.models.model import FantasizeMixin, Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
Expand Down Expand Up @@ -50,6 +51,16 @@ def setUp(self):
message="The model inputs are of type",
category=UserWarning,
)
warnings.filterwarnings(
"ignore",
message="Non-strict enforcement of botorch tensor conventions.",
category=BotorchTensorDimensionWarning,
)
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
category=InputDataWarning,
)

def assertAllClose(
self,
Expand Down
24 changes: 14 additions & 10 deletions test/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,32 @@ def test_flags(self):
self.assertTrue(flag.off())

def test_debug(self):
# turn on BotorchWarning
# Turn on debug.
settings.debug._set_state(True)
# check that warnings are suppressed
# Check that debug warnings are suppressed when it is turned off.
with settings.debug(False):
with warnings.catch_warnings(record=True) as ws:
warnings.warn("test", BotorchWarning)
if settings.debug.on():
warnings.warn("test", BotorchWarning)
self.assertEqual(len(ws), 0)
# check that warnings are not suppressed outside of context manager
# Check that warnings are not suppressed outside of context manager.
with warnings.catch_warnings(record=True) as ws:
warnings.warn("test", BotorchWarning)
if settings.debug.on():
warnings.warn("test", BotorchWarning)
self.assertEqual(len(ws), 1)

# turn off BotorchWarnings
# Turn off debug.
settings.debug._set_state(False)
# check that warnings are not suppressed
# Check that warnings are not suppressed within debug.
with settings.debug(True):
with warnings.catch_warnings(record=True) as ws:
warnings.warn("test", BotorchWarning)
if settings.debug.on():
warnings.warn("test", BotorchWarning)
self.assertEqual(len(ws), 1)
# check that warnings are suppressed outside of context manager
# Check that warnings are suppressed outside of context manager.
with warnings.catch_warnings(record=True) as ws:
warnings.warn("test", BotorchWarning)
if settings.debug.on():
warnings.warn("test", BotorchWarning)
self.assertEqual(len(ws), 0)


Expand Down

0 comments on commit 413dfa0

Please sign in to comment.