Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Oct 5, 2020
1 parent e0f8505 commit 29a2aa6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
22 changes: 5 additions & 17 deletions tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from pytorch_lightning.loggers.base import DummyExperiment
from tests.base import EvalModelTemplate
from tests.loggers.test_comet import _patch_comet_atexit


def _get_logger_args(logger_class, save_dir):
Expand All @@ -45,11 +46,7 @@ def _get_logger_args(logger_class, save_dir):
def test_loggers_fit_test(wandb, neptune, tmpdir, monkeypatch, logger_class):
"""Verify that basic functionality of all loggers."""
os.environ['PL_DEV_DEBUG'] = '0'

if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)

model = EvalModelTemplate()

Expand Down Expand Up @@ -110,10 +107,7 @@ def log_metrics(self, metrics, step):
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class):
""" Test the combinations of save_dir, weights_save_path and default_root_dir. """
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)

class TestLogger(logger_class):
# for this test it does not matter what these attributes are
Expand Down Expand Up @@ -173,10 +167,7 @@ def name(self):
@mock.patch('pytorch_lightning.loggers.neptune.neptune')
def test_loggers_pickle(neptune, tmpdir, monkeypatch, logger_class):
"""Verify that pickling trainer with logger works."""
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)

logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)
Expand Down Expand Up @@ -250,10 +241,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_
@mock.patch('pytorch_lightning.loggers.neptune.neptune')
def test_logger_created_on_rank_zero_only(neptune, tmpdir, monkeypatch, logger_class):
""" Test that loggers get replaced by dummy loggers on global rank > 0"""
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)

logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)
Expand Down
16 changes: 8 additions & 8 deletions tests/loggers/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from tests.base import EvalModelTemplate


def _patch_comet_atexit(monkeypatch):
""" Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it. """
import atexit
monkeypatch.setattr(atexit, "register", lambda _: None)


def test_comet_logger_online():
"""Test comet online with mocks."""
# Test api_key given
Expand Down Expand Up @@ -76,11 +82,7 @@ def test_comet_logger_experiment_name():

def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
""" Test that the logger creates the folders and files in the right place. """
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit

monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)

logger = CometLogger(project_name='test', save_dir=tmpdir)
assert not os.listdir(tmpdir)
Expand Down Expand Up @@ -159,9 +161,7 @@ def test_comet_version_without_experiment():

def test_comet_epoch_logging(tmpdir, monkeypatch):
""" Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """
import atexit

monkeypatch.setattr(atexit, "register", lambda _: None)
_patch_comet_atexit(monkeypatch)
with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics:
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
Expand Down

0 comments on commit 29a2aa6

Please sign in to comment.