Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG - Wandb: Sanitize callable. #4320

Merged
merged 12 commits into from
Oct 26, 2020
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@ mlruns/
*.ckpt
pytorch\ lightning
test-reports/
wandb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how it happens that this folder is created?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wandb logger always makes this folder, we can't avoid this. I have it too. I think it's ok to add it to gitignore.

2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `hparams` assign in init ([#4189](https://github.com/PyTorchLightning/pytorch-lightning/pull/4189))
- Fixed overwrite check for model hooks ([#4010](https://github.com/PyTorchLightning/pytorch-lightning/pull/4010))
- Fixed santized parameters for `wandb_logger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
Borda marked this conversation as resolved.
Show resolved Hide resolved



## [1.0.2] - 2020-10-15
Expand Down
25 changes: 25 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:

return params

@staticmethod
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x*****> -> {'a': 'function_****'}``.

Args:
params: Dictionary containing the hyperparameters

Returns:
dict.
"""
Borda marked this conversation as resolved.
Show resolved Hide resolved
def _sanitize_callable(val):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# Give them one chance to return a value. Don't go rabbit hole of recursive call
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return val.__name__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #4380, not all objects have __name__
maybe we need __class__.__name__

return val

return {key: _sanitize_callable(val) for key, val in params.items()}

@staticmethod
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
"""
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
params = self._sanitize_callable_params(params)
self.experiment.config.update(params, allow_val_change=True)

@rank_zero_only
Expand Down
29 changes: 29 additions & 0 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
import pickle
from unittest import mock
from argparse import ArgumentParser
import types

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
Expand Down Expand Up @@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}


def test_wandb_sanitize_callable_params(tmpdir):
"""
Callback function are not serializiable. Therefore, we get them a chance to return
something and if the returned type is not accepted, return None.
"""
opt = "--max_epochs 1".split(" ")
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parent_parser=parser)
params = parser.parse_args(opt)

def return_something():
return "something"
params.something = return_something

def wrapper_something():
return return_something
params.wrapper_something = wrapper_something

assert isinstance(params.gpus, types.FunctionType)
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == '_gpus_arg_default'
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"