From 4f12f5e2de1e1c72db7b36da13fcce45d48007f9 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Mon, 23 Jan 2023 14:05:27 -0800
Subject: [PATCH] Do not filter out BoTorchWarnings by default (#1630)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1630

Updates the `debug` flag to be opt-in only. Previously, it filtered out all
BoTorchWarnings, which significantly limited the usefulness of these warnings.

Reviewed By: esantorella

Differential Revision: D42017622

fbshipit-source-id: c03115ea404d166de36be53af0aa07a6ebfd7ceb
---
 botorch/optim/utils/timeout.py     |  3 --
 botorch/settings.py                | 23 +++------------
 botorch/utils/testing.py           | 11 ++++++++
 test/models/utils/test_assorted.py | 45 +++++++++++++-----------------
 test/test_settings.py              | 24 +++++++++-------
 5 files changed, 49 insertions(+), 57 deletions(-)

diff --git a/botorch/optim/utils/timeout.py b/botorch/optim/utils/timeout.py
index d0f7bd7f07..474d0d52c4 100644
--- a/botorch/optim/utils/timeout.py
+++ b/botorch/optim/utils/timeout.py
@@ -7,12 +7,10 @@
 from __future__ import annotations
 
 import time
-import warnings
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
 from botorch.exceptions.errors import OptimizationTimeoutError
-from botorch.exceptions.warnings import OptimizationWarning
 from scipy import optimize
 
 
@@ -95,7 +93,6 @@ def wrapped_callback(xk: np.ndarray) -> None:
         )
     except OptimizationTimeoutError as e:
         msg = f"Optimization timed out after {e.runtime} seconds."
-        warnings.warn(msg, OptimizationWarning)
         current_fun, *_ = fun(e.current_x, *args)
 
     return optimize.OptimizeResult(
diff --git a/botorch/settings.py b/botorch/settings.py
index 8261237060..d838f64f83 100644
--- a/botorch/settings.py
+++ b/botorch/settings.py
@@ -10,10 +10,6 @@
 
 from __future__ import annotations
 
-import typing  # noqa F401
-import warnings
-
-from botorch.exceptions import BotorchWarning
 from botorch.logging import LOG_LEVEL_DEFAULT, logger
 
 
@@ -55,30 +51,19 @@ class propagate_grads(_Flag):
     _state: bool = False
 
 
-def suppress_botorch_warnings(suppress: bool) -> None:
-    r"""Set botorch warning filter.
-
-    Args:
-        state: A boolean indicating whether warnings should be prints
-    """
-    warnings.simplefilter("ignore" if suppress else "default", BotorchWarning)
-
-
 class debug(_Flag):
-    r"""Flag for printing verbose BotorchWarnings.
+    r"""Flag for printing verbose warnings.
 
-    When set to `True`, verbose `BotorchWarning`s will be printed for debuggability.
-    Warnings that are not subclasses of `BotorchWarning` will not be affected by
-    this context_manager.
+    To make sure a warning is only raised in debug mode:
+    >>> if debug.on():
+    >>>     warnings.warn(<some warning>)
     """
 
     _state: bool = False
-    suppress_botorch_warnings(suppress=not _state)
 
     @classmethod
     def _set_state(cls, state: bool) -> None:
         cls._state = state
-        suppress_botorch_warnings(suppress=not cls._state)
 
 
 class validate_input_scaling(_Flag):
diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py
index c71129b483..0754378f8c 100644
--- a/botorch/utils/testing.py
+++ b/botorch/utils/testing.py
@@ -15,6 +15,7 @@
 import torch
 from botorch import settings
 from botorch.acquisition.objective import PosteriorTransform
+from botorch.exceptions.warnings import BotorchTensorDimensionWarning, InputDataWarning
 from botorch.models.model import FantasizeMixin, Model
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
@@ -50,6 +51,16 @@ def setUp(self):
             message="The model inputs are of type",
             category=UserWarning,
         )
+        warnings.filterwarnings(
+            "ignore",
+            message="Non-strict enforcement of botorch tensor conventions.",
+            category=BotorchTensorDimensionWarning,
+        )
+        warnings.filterwarnings(
+            "ignore",
+            message="Input data is not standardized.",
+            category=InputDataWarning,
+        )
 
     def assertAllClose(
         self,
diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py
index 49a5682264..493e360f93 100644
--- a/test/models/utils/test_assorted.py
+++ b/test/models/utils/test_assorted.py
@@ -134,34 +134,29 @@ def test_check_min_max_scaling(self):
         )
 
     def test_check_standardization(self):
+        # Ensure that it is not filtered out.
+        warnings.filterwarnings("always", category=InputDataWarning)
         Y = torch.randn(3, 4, 2)
         # check standardized input
         Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
-        with settings.debug(True):
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst)
-                self.assertFalse(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-            check_standardization(Y=Yst, raise_on_fail=True)
-            # check nonzero mean
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst + 1)
-                self.assertTrue(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-                self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-            with self.assertRaises(InputDataError):
-                check_standardization(Y=Yst + 1, raise_on_fail=True)
-            # check non-unit variance
-            with warnings.catch_warnings(record=True) as ws:
-                check_standardization(Y=Yst * 2)
-                self.assertTrue(
-                    any(issubclass(w.category, InputDataWarning) for w in ws)
-                )
-                self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-            with self.assertRaises(InputDataError):
-                check_standardization(Y=Yst * 2, raise_on_fail=True)
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst)
+            self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
+        check_standardization(Y=Yst, raise_on_fail=True)
+        # check nonzero mean
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst + 1)
+            self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
+            self.assertTrue(any("not standardized" in str(w.message) for w in ws))
+        with self.assertRaises(InputDataError):
+            check_standardization(Y=Yst + 1, raise_on_fail=True)
+        # check non-unit variance
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=Yst * 2)
+            self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
self.assertTrue(any("not standardized" in str(w.message) for w in ws)) + with self.assertRaises(InputDataError): + check_standardization(Y=Yst * 2, raise_on_fail=True) def test_validate_input_scaling(self): train_X = 2 + torch.rand(3, 4, 3) diff --git a/test/test_settings.py b/test/test_settings.py index 31b72d658e..69668ae76e 100644 --- a/test/test_settings.py +++ b/test/test_settings.py @@ -25,28 +25,32 @@ def test_flags(self): self.assertTrue(flag.off()) def test_debug(self): - # turn on BotorchWarning + # Turn on debug. settings.debug._set_state(True) - # check that warnings are suppressed + # Check that debug warnings are suppressed when it is turned off. with settings.debug(False): with warnings.catch_warnings(record=True) as ws: - warnings.warn("test", BotorchWarning) + if settings.debug.on(): + warnings.warn("test", BotorchWarning) self.assertEqual(len(ws), 0) - # check that warnings are not suppressed outside of context manager + # Check that warnings are not suppressed outside of context manager. with warnings.catch_warnings(record=True) as ws: - warnings.warn("test", BotorchWarning) + if settings.debug.on(): + warnings.warn("test", BotorchWarning) self.assertEqual(len(ws), 1) - # turn off BotorchWarnings + # Turn off debug. settings.debug._set_state(False) - # check that warnings are not suppressed + # Check that warnings are not suppressed within debug. with settings.debug(True): with warnings.catch_warnings(record=True) as ws: - warnings.warn("test", BotorchWarning) + if settings.debug.on(): + warnings.warn("test", BotorchWarning) self.assertEqual(len(ws), 1) - # check that warnings are suppressed outside of context manager + # Check that warnings are suppressed outside of context manager. with warnings.catch_warnings(record=True) as ws: - warnings.warn("test", BotorchWarning) + if settings.debug.on(): + warnings.warn("test", BotorchWarning) self.assertEqual(len(ws), 0)