Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mocking Loggers (part 3c, comet) #3859

Merged
merged 2 commits into from
Oct 5, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions tests/loggers/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,30 @@
from tests.base import EvalModelTemplate


def test_comet_logger_online():
def _patch_comet_atexit(monkeypatch):
""" Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it. """
import atexit
monkeypatch.setattr(atexit, "register", lambda _: None)


@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_online(comet):
"""Test comet online with mocks."""
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test both given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test neither given
with pytest.raises(MisconfigurationException):
CometLogger(workspace='dummy-test', project_name='general')
comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test already exists
with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
Expand All @@ -55,56 +58,72 @@ def test_comet_logger_online():
api.assert_called_once_with('rest')


def test_comet_logger_experiment_name():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_no_api_key_given(comet):
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
with pytest.raises(MisconfigurationException):
comet.config.get_api_key.return_value = None
CometLogger(workspace='dummy-test', project_name='general')


@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_experiment_name(comet):
"""Test that Comet Logger experiment name works correctly."""

api_key = "key"
experiment_name = "My Name"

# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet.assert_called_once_with(api_key=api_key, project_name=None)
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

comet().set_name.assert_called_once_with(experiment_name)
comet_experiment().set_name.assert_called_once_with(experiment_name)


def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment')
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that the logger creates the folders and files in the right place. """
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit

monkeypatch.setattr(atexit, 'register', lambda _: None)
_patch_comet_atexit(monkeypatch)
comet.config.get_api_key.return_value = None
comet.generate_guid.return_value = "4321"

logger = CometLogger(project_name='test', save_dir=tmpdir)
assert not os.listdir(tmpdir)
assert logger.mode == 'offline'
assert logger.save_dir == tmpdir
assert logger.name == 'test'
assert logger.version == "4321"

_ = logger.experiment
version = logger.version
assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'}

comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test')

# mock return values of experiment
logger.experiment.id = '1'
logger.experiment.project_name = 'test'

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}


def test_comet_name_default():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_name_default(comet):
""" Test that CometLogger.name don't create an Experiment and returns a default value. """

api_key = "key"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)

assert logger._experiment is None
Expand All @@ -114,13 +133,14 @@ def test_comet_name_default():
assert logger._experiment is None


def test_comet_name_project_name():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_name_project_name(comet):
""" Test that CometLogger.name does not create an Experiment and returns project name if passed. """

api_key = "key"
project_name = "My Project Name"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None
Expand All @@ -130,13 +150,15 @@ def test_comet_name_project_name():
assert logger._experiment is None


def test_comet_version_without_experiment():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_version_without_experiment(comet):
""" Test that CometLogger.version does not create an Experiment. """

api_key = "key"
experiment_name = "My Name"
comet.generate_guid.return_value = "1234"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None
Expand All @@ -152,17 +174,16 @@ def test_comet_version_without_experiment():

logger.reset_experiment()

second_version = logger.version
second_version = logger.version == "1234"
assert second_version is not None
assert second_version != first_version


def test_comet_epoch_logging(tmpdir, monkeypatch):
@patch("pytorch_lightning.loggers.comet.CometExperiment")
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """
import atexit

monkeypatch.setattr(atexit, "register", lambda _: None)
with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics:
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
_patch_comet_atexit(monkeypatch)
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)