diff --git a/.gitignore b/.gitignore
index db35ac44c6207..fff549a718794 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,3 +138,4 @@ mlruns/
 *.ckpt
 pytorch\ lightning
 test-reports/
+wandb
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc7ec0401d93c..9a534c6bfaf40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,28 +11,44 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
 
+
 - Added plugins docs and DDPPlugin to customize ddp across all accelerators ([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285))
 
+
 - Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586))
 
+
 - Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162))
 
+
 ### Changed
 
+
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
 
+
 - Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
 
+
+
 - Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))
 
+
+- Fixed sanitizing of parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
+
+
 ### Deprecated
 
+
 - Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
 
+
 - Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
 
+
 ### Removed
 
+
 ### Fixed
 
 - Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))
diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py
index 8f72830027806..cf0b22d7d446f 100644
--- a/pytorch_lightning/loggers/base.py
+++ b/pytorch_lightning/loggers/base.py
@@ -168,6 +168,31 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]
 
         return params
 
+    @staticmethod
+    def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
+
+        Args:
+            params: Dictionary containing the hyperparameters
+
+        Returns:
+            dictionary with all callables sanitized
+        """
+        def _sanitize_callable(val):
+            # Give each callable one chance to return a value; don't call the result again, to avoid recursing forever.
+            if isinstance(val, Callable):
+                try:
+                    _val = val()
+                    if isinstance(_val, Callable):
+                        return val.__name__
+                    return _val
+                except Exception:
+                    return val.__name__
+            return val
+
+        return {key: _sanitize_callable(val) for key, val in params.items()}
+
     @staticmethod
     def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
         """
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index ca2b04d86aea8..e6ce264d597bf 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -135,6 +135,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
         params = self._flatten_dict(params)
+        params = self._sanitize_callable_params(params)
         self.experiment.config.update(params, allow_val_change=True)
 
     @rank_zero_only
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index e87d1dff126d9..6682cfdc8830a 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -14,6 +14,8 @@
 import os
 import pickle
 from unittest import mock
+from argparse import ArgumentParser
+import types
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
@@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
 
     assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
     assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+
+
+def test_wandb_sanitize_callable_params(tmpdir):
+    """
+    Callable params are not serializable. Give each callable one chance to return
+    a value; if the call fails or returns another callable, fall back to its name.
+    """
+    opt = "--max_epochs 1".split(" ")
+    parser = ArgumentParser()
+    parser = Trainer.add_argparse_args(parent_parser=parser)
+    params = parser.parse_args(opt)
+
+    def return_something():
+        return "something"
+    params.something = return_something
+
+    def wrapper_something():
+        return return_something
+    params.wrapper_something = wrapper_something
+
+    assert isinstance(params.gpus, types.FunctionType)
+    params = WandbLogger._convert_params(params)
+    params = WandbLogger._flatten_dict(params)
+    params = WandbLogger._sanitize_callable_params(params)
+    assert params["gpus"] == '_gpus_arg_default'
+    assert params["something"] == "something"
+    assert params["wrapper_something"] == "wrapper_something"
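
For reference, a minimal standalone sketch of the sanitization behavior introduced above, runnable without Lightning. The top-level function name `sanitize_callable_params`, the helper `default_lr`, and the sample hyperparameter dict are illustrative only, not part of the library API; the inner logic mirrors the `_sanitize_callable_params` added in this diff.

# Minimal, self-contained sketch (illustrative reimplementation, not the
# pytorch-lightning API itself) of the sanitization logic added above.
from typing import Any, Callable, Dict


def sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
    def _sanitize_callable(val):
        # Give each callable one chance to produce a loggable value;
        # never call the result again, so there is no recursion.
        if isinstance(val, Callable):
            try:
                _val = val()
                if isinstance(_val, Callable):
                    return val.__name__  # result is still a callable -> use the name
                return _val              # plain value -> log it directly
            except Exception:
                return val.__name__      # call failed (e.g. needs args) -> use the name
        return val

    return {key: _sanitize_callable(val) for key, val in params.items()}


def default_lr():
    return 1e-3


print(sanitize_callable_params({
    "lr": default_lr,    # zero-arg callable returning a value -> 0.001
    "clean": str.upper,  # raises when called without args     -> 'upper'
    "batch_size": 32,    # non-callable                        -> unchanged
}))
# -> {'lr': 0.001, 'clean': 'upper', 'batch_size': 32}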