Add possibility for custom naming when using multiple dataloaders #6274

Merged (9 commits) on Mar 2, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


### Changed

- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
18 changes: 15 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -225,6 +225,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
):
"""
Log a key, value
@@ -259,7 +260,10 @@
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
"""
if self._results is not None:
# in any epoch end can't log step metrics (only epoch metric)
@@ -291,6 +295,9 @@

training_type_plugin = self.trainer.training_type_plugin

# Determine if dataloader index should be added
dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None

self._results.log(
name,
value,
@@ -306,7 +313,7 @@
sync_dist_op,
sync_dist_group,
training_type_plugin.reduce,
self._current_dataloader_idx,
dataloader_idx,
self.device,
)

@@ -324,6 +331,7 @@ def log_dict(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
):
"""
Log a dictionary of values at once
@@ -345,7 +353,10 @@
enable_graph: if True, will not auto detach the graph
sync_dist: if True, reduces the metric across GPUs/TPUs
sync_dist_op: the op to sync across GPUs/TPUs
sync_dist_group: the ddp group:
sync_dist_group: the ddp group to sync across
add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values
"""
for k, v in dictionary.items():
self.log(
Expand All @@ -362,6 +373,7 @@ def log_dict(
sync_dist_op=sync_dist_op,
tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx
)

def write_prediction(
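As a quick illustration of the new flag (not part of the PR diff), here is a minimal sketch of using `add_dataloader_idx` from a `LightningModule` with two validation dataloaders. The model class, dataset, and metric names below are assumptions made for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule


class MultiValModel(LightningModule):
    """Illustrative module with two validation dataloaders (training setup elided)."""

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32))
        # returning two dataloaders makes Lightning pass dataloader_idx to validation_step
        return [DataLoader(ds, batch_size=8), DataLoader(ds, batch_size=8)]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        loss = batch[0].mean()  # stand-in for a real validation loss

        # default behaviour: the key gets a per-dataloader suffix,
        # e.g. "val_loss/dataloader_idx_0" and "val_loss/dataloader_idx_1"
        self.log("val_loss", loss)

        # custom naming: with add_dataloader_idx=False the key is used as-is,
        # so it must already be unique per dataloader
        name = "val_loss_clean" if dataloader_idx == 0 else "val_loss_noisy"
        self.log(name, loss, add_dataloader_idx=False)
```

With the default, both dataloaders log under suffixed `val_loss/...` keys; with the flag disabled, the module is responsible for keeping the names distinct, which is exactly what the test below checks.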
38 changes: 38 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -471,3 +471,41 @@ def training_step(self, *args, **kwargs):
)
with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
trainer.fit(model)


@pytest.mark.parametrize("add_dataloader_idx", [False, True])
def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
""" test that auto_add_dataloader_idx argument works """

class TestModel(BoringModel):
def val_dataloader(self):
dl = super().val_dataloader()
return [dl, dl]

def validation_step(self, *args, **kwargs):
output = super().validation_step(*args[:-1], **kwargs)
if add_dataloader_idx:
name = "val_loss"
else:
name = f"val_loss_custom_naming_{args[-1]}"

self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx)
return output

model = TestModel()
model.validation_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
max_steps=5
)
trainer.fit(model)
logged = trainer.logged_metrics

# Check that the correct keys exist
if add_dataloader_idx:
assert 'val_loss/dataloader_idx_0' in logged
assert 'val_loss/dataloader_idx_1' in logged
else:
assert 'val_loss_custom_naming_0' in logged
assert 'val_loss_custom_naming_1' in logged
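
The test exercises `self.log` only; since `log_dict` simply forwards `add_dataloader_idx` to each underlying `self.log` call, the same idea carries over. A hedged sketch of a `validation_step` body for a module like the one sketched earlier (the loss/accuracy values and key names are illustrative assumptions):

```python
def validation_step(self, batch, batch_idx, dataloader_idx):
    loss = batch[0].mean()      # stand-in for a real loss
    acc = (loss < 0).float()    # stand-in for a real accuracy metric
    suffix = f"custom_naming_{dataloader_idx}"

    # add_dataloader_idx=False is passed through to every self.log call made by
    # log_dict, so these keys are logged exactly as written
    self.log_dict(
        {f"val_loss_{suffix}": loss, f"val_acc_{suffix}": acc},
        add_dataloader_idx=False,
    )
```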