diff --git a/CHANGELOG.md b/CHANGELOG.md index 94675a20111c0..968d384af717a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669)) +- Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/issues/1485)) ## [0.7.5] - 2020-04-27 diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 39891c447bba2..857d661fdb5b3 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -125,7 +125,7 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N """ agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step) - if metrics_to_log is not None: + if metrics_to_log: self.log_metrics(metrics=metrics_to_log, step=agg_step) @abstractmethod diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index c348644141fca..0d5ff9855a40d 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -119,9 +119,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - if step is not None: - metrics['global_step'] = step - self.experiment.log(metrics) + self.experiment.log(metrics, step=step) @property def name(self) -> str: diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index bb2739f95f6e0..cb9aad20315e9 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -16,11 +16,11 @@ def test_wandb_logger(wandb): logger = WandbLogger(anonymous=True, offline=True) logger.log_metrics({'acc': 1.0}) - wandb.init().log.assert_called_once_with({'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) wandb.init().log.reset_mock() logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0}) + wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) logger.log_hyperparams({'test': None}) wandb.init().config.update.assert_called_once_with({'test': None}, allow_val_change=True)