From ec9142be9edf199d00a680b24388d54d2e05e058 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Feb 2021 16:09:29 +0100 Subject: [PATCH 1/8] add prefix --- docs/source/extensions/metrics.rst | 15 ++++++++------- pytorch_lightning/metrics/metric.py | 20 +++++++++++++++++--- tests/metrics/test_metric_lightning.py | 14 +++++++------- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst index 6a64c42ec2753..27f5de1b3ead1 100644 --- a/docs/source/extensions/metrics.rst +++ b/docs/source/extensions/metrics.rst @@ -384,23 +384,24 @@ inside your LightningModule def __init__(self): ... - metrics = pl.metrics.MetricCollection(...) - self.train_metrics = metrics.clone() - self.valid_metrics = metrics.clone() + self.train_metrics = pl.metrics.MetricCollection(Accuracy(), Precision(), Recall(), prefix='train_') + self.valid_metrics = pl.metrics.MetricCollection(Accuracy(), Precision(), Recall(), prefix='val_') def training_step(self, batch, batch_idx): logits = self(x) ... - self.train_metrics(logits, y) + output = self.train_metrics(logits, y) # use log_dict instead of log - self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') + # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall + self.log_dict(output) def validation_step(self, batch, batch_idx): logits = self(x) ... - self.valid_metrics(logits, y) + output = self.valid_metrics(logits, y) # use log_dict instead of log - self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') + # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall + self.log_dict(output) .. note:: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index ab198356f7279..6aa2764c59363 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -528,6 +528,8 @@ class MetricCollection(nn.ModuleDict): dict as key for output dict. Use this format if you want to chain together multiple of the same metric with different parameters. + prefix: a string to append in front of the keys of the output dict + Example (input as list): >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall @@ -548,7 +550,10 @@ class MetricCollection(nn.ModuleDict): """ - def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): + def __init__( + self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]], + prefix: Optional[str] = None + ): super().__init__() if isinstance(metrics, dict): # Check all values are metrics @@ -573,13 +578,19 @@ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric] else: raise ValueError("Unknown input to MetricCollection.") + if prefix is not None: + if isinstance(prefix, str): + self.prefix = prefix + else: + raise ValueError('Expected input `prefix` to be a string') + def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 """ Iteratively call forward for each metric. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. 
""" - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} def update(self, *args, **kwargs): # pylint: disable=E0202 """ @@ -592,7 +603,7 @@ def update(self, *args, **kwargs): # pylint: disable=E0202 m.update(*args, **m_kwargs) def compute(self) -> Dict[str, Any]: - return {k: m.compute() for k, m in self.items()} + return {self._set_prefix(k): m.compute() for k, m in self.items()} def reset(self): """ Iteratively call reset for each metric """ @@ -609,3 +620,6 @@ def persistent(self, mode: bool = True): """ for _, m in self.items(): m.persistent(mode) + + def _set_prefix(self, k): + return k if self.prefix is None else self.prefix+k \ No newline at end of file diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..b19d2908e30b7 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -149,26 +149,26 @@ def training_step(self, batch, batch_idx): def test_metric_collection_lightning_log(tmpdir): - + """ test that logging in lightning works with the MetricCollection class """ class TestModel(BoringModel): def __init__(self): super().__init__() - self.metric = MetricCollection([SumMetric(), DiffMetric()]) + self.metric = MetricCollection([SumMetric(), DiffMetric()], prefix='train_') self.sum = 0.0 self.diff = 0.0 def training_step(self, batch, batch_idx): x = batch - metric_vals = self.metric(x.sum()) + self.metric(x.sum()) self.sum += x.sum() self.diff -= x.sum() - self.log_dict({f'{k}_step': v for k, v in metric_vals.items()}) + self.log_dict(self.metric) return self.step(x) def training_epoch_end(self, outputs): metric_vals = self.metric.compute() - self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()}) + self.log_dict(metric_vals) model = TestModel() model.val_dataloader = None @@ -184,5 +184,5 @@ def training_epoch_end(self, outputs): trainer.fit(model) logged = trainer.logged_metrics - assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) - assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff) + assert torch.allclose(torch.tensor(logged["train_SumMetric"]), model.sum) + assert torch.allclose(torch.tensor(logged["train_DiffMetric"]), model.diff) From 9da7efcf712c4e94205cf8210df343b1771421fd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Mar 2021 20:33:06 +0100 Subject: [PATCH 2/8] add flag --- pytorch_lightning/core/lightning.py | 17 +++++++-- .../trainer/logging_/test_logger_connector.py | 35 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c4d63cff4637b..3b06bd6c39a45 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -226,6 +226,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + auto_add_dataloader_idx: bool = True, ): """ Log a key, value @@ -260,7 +261,10 @@ def log( enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs - sync_dist_group: the ddp group + sync_dist_group: the ddp group to sync across + auto_add_dataloader_idx: if True, appends the index of the current dataloader to + the name (when using multiple). 
If False, user needs to give unique names for + each dataloader to not mix values """ if self._results is not None: # in any epoch end can't log step metrics (only epoch metric) @@ -292,6 +296,8 @@ def log( training_type_plugin = self.trainer.training_type_plugin + dataloader_idx = self._current_dataloader_idx if auto_add_dataloader_idx else None + self._results.log( name, value, @@ -307,7 +313,7 @@ def log( sync_dist_op, sync_dist_group, training_type_plugin.reduce, - self._current_dataloader_idx, + dataloader_idx, self.device, ) @@ -325,6 +331,7 @@ def log_dict( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + auto_add_dataloader_idx: bool = True, ): """ Log a dictonary of values at once @@ -346,7 +353,10 @@ def log_dict( enable_graph: if True, will not auto detach the graph sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs - sync_dist_group: the ddp group: + sync_dist_group: the ddp group sync across + auto_add_dataloader_idx: if True, appends the index of the current dataloader to + the name (when using multiple). If False, user needs to give unique names for + each dataloader to not mix values """ for k, v in dictionary.items(): self.log( @@ -363,6 +373,7 @@ def log_dict( sync_dist_op=sync_dist_op, tbptt_pad_token=tbptt_pad_token, tbptt_reduce_fx=tbptt_reduce_fx, + auto_add_dataloader_idx=auto_add_dataloader_idx ) def write_prediction( diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 92eb2c76a8c6b..65fecb16223ab 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -28,6 +28,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel, RandomDataset @@ -470,3 +471,37 @@ def training_step(self, *args, **kwargs): ) with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): trainer.fit(model) + + +@pytest.mark.parametrize("auto_add_dataloader_idx", [False, True]) +def test_auto_add_dataloader_idx(tmpdir, auto_add_dataloader_idx): + """ test that auto_add_dataloader_idx argument works """ + + class TestModel(EvalModelTemplate): + + def validation_step(self, *args, **kwargs): + output = super().validation_step(*args, **kwargs) + if auto_add_dataloader_idx: + name = "val_loss" + else: + name = f"val_loss_custom_naming_{args[-1]}" + + self.log(name, output["val_loss"], auto_add_dataloader_idx=auto_add_dataloader_idx) + return output + + model = TestModel() + model.val_dataloader = model.val_dataloader__multiple + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=5 + ) + trainer.fit(model) + logged = trainer.logged_metrics + + if auto_add_dataloader_idx: + assert 'val_loss/dataloader_idx_0' in logged + assert 'val_loss/dataloader_idx_1' in logged + else: + assert 'val_loss_custom_naming_0' in logged + assert 'val_loss_custom_naming_1' in logged From 0783bb604c8cfe18dc1a1745cc1d85b90c500ad4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Mar 2021 20:39:25 +0100 Subject: [PATCH 3/8] remove unrelated changes --- docs/source/extensions/metrics.rst | 15 +++++++-------- 
pytorch_lightning/metrics/metric.py | 20 +++----------------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst index 27f5de1b3ead1..6a64c42ec2753 100644 --- a/docs/source/extensions/metrics.rst +++ b/docs/source/extensions/metrics.rst @@ -384,24 +384,23 @@ inside your LightningModule def __init__(self): ... - self.train_metrics = pl.metrics.MetricCollection(Accuracy(), Precision(), Recall(), prefix='train_') - self.valid_metrics = pl.metrics.MetricCollection(Accuracy(), Precision(), Recall(), prefix='val_') + metrics = pl.metrics.MetricCollection(...) + self.train_metrics = metrics.clone() + self.valid_metrics = metrics.clone() def training_step(self, batch, batch_idx): logits = self(x) ... - output = self.train_metrics(logits, y) + self.train_metrics(logits, y) # use log_dict instead of log - # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall - self.log_dict(output) + self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') def validation_step(self, batch, batch_idx): logits = self(x) ... - output = self.valid_metrics(logits, y) + self.valid_metrics(logits, y) # use log_dict instead of log - # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall - self.log_dict(output) + self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') .. note:: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 6aa2764c59363..ab198356f7279 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -528,8 +528,6 @@ class MetricCollection(nn.ModuleDict): dict as key for output dict. Use this format if you want to chain together multiple of the same metric with different parameters. - prefix: a string to append in front of the keys of the output dict - Example (input as list): >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall @@ -550,10 +548,7 @@ class MetricCollection(nn.ModuleDict): """ - def __init__( - self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]], - prefix: Optional[str] = None - ): + def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): super().__init__() if isinstance(metrics, dict): # Check all values are metrics @@ -578,19 +573,13 @@ def __init__( else: raise ValueError("Unknown input to MetricCollection.") - if prefix is not None: - if isinstance(prefix, str): - self.prefix = prefix - else: - raise ValueError('Expected input `prefix` to be a string') - def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 """ Iteratively call forward for each metric. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. 
""" - return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} def update(self, *args, **kwargs): # pylint: disable=E0202 """ @@ -603,7 +592,7 @@ def update(self, *args, **kwargs): # pylint: disable=E0202 m.update(*args, **m_kwargs) def compute(self) -> Dict[str, Any]: - return {self._set_prefix(k): m.compute() for k, m in self.items()} + return {k: m.compute() for k, m in self.items()} def reset(self): """ Iteratively call reset for each metric """ @@ -620,6 +609,3 @@ def persistent(self, mode: bool = True): """ for _, m in self.items(): m.persistent(mode) - - def _set_prefix(self, k): - return k if self.prefix is None else self.prefix+k \ No newline at end of file From 1fc20b9c7122e35310199b4466fc24055c6aa83b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Mar 2021 20:42:13 +0100 Subject: [PATCH 4/8] remove --- tests/metrics/test_metric_lightning.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index b19d2908e30b7..895305fa9da7e 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -149,26 +149,26 @@ def training_step(self, batch, batch_idx): def test_metric_collection_lightning_log(tmpdir): - """ test that logging in lightning works with the MetricCollection class """ + class TestModel(BoringModel): def __init__(self): super().__init__() - self.metric = MetricCollection([SumMetric(), DiffMetric()], prefix='train_') + self.metric = MetricCollection([SumMetric(), DiffMetric()]) self.sum = 0.0 self.diff = 0.0 def training_step(self, batch, batch_idx): x = batch - self.metric(x.sum()) + metric_vals = self.metric(x.sum()) self.sum += x.sum() self.diff -= x.sum() - self.log_dict(self.metric) + self.log_dict({f'{k}_step': v for k, v in metric_vals.items()}) return self.step(x) def training_epoch_end(self, outputs): metric_vals = self.metric.compute() - self.log_dict(metric_vals) + self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()}) model = TestModel() model.val_dataloader = None @@ -184,5 +184,5 @@ def training_epoch_end(self, outputs): trainer.fit(model) logged = trainer.logged_metrics - assert torch.allclose(torch.tensor(logged["train_SumMetric"]), model.sum) - assert torch.allclose(torch.tensor(logged["train_DiffMetric"]), model.diff) + assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) + assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff) From b0d5ce4cc4841f836ec21a5dda0229205509316d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 2 Mar 2021 10:03:24 +0100 Subject: [PATCH 5/8] remove auto --- pytorch_lightning/core/lightning.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3b06bd6c39a45..9786e7f8dc494 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -226,7 +226,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, - auto_add_dataloader_idx: bool = True, + add_dataloader_idx: bool = True, ): """ Log a key, value @@ -262,7 +262,7 @@ def log( sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group to sync across - auto_add_dataloader_idx: if True, 
appends the index of the current dataloader to + add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values """ @@ -295,8 +295,9 @@ def log( ) training_type_plugin = self.trainer.training_type_plugin - - dataloader_idx = self._current_dataloader_idx if auto_add_dataloader_idx else None + + # Determine if dataloader index should be added + dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None self._results.log( name, @@ -331,7 +332,7 @@ def log_dict( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, - auto_add_dataloader_idx: bool = True, + add_dataloader_idx: bool = True, ): """ Log a dictonary of values at once @@ -354,7 +355,7 @@ def log_dict( sync_dist: if True, reduces the metric across GPUs/TPUs sync_dist_op: the op to sync across GPUs/TPUs sync_dist_group: the ddp group sync across - auto_add_dataloader_idx: if True, appends the index of the current dataloader to + add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values """ @@ -373,7 +374,7 @@ def log_dict( sync_dist_op=sync_dist_op, tbptt_pad_token=tbptt_pad_token, tbptt_reduce_fx=tbptt_reduce_fx, - auto_add_dataloader_idx=auto_add_dataloader_idx + auto_add_dataloader_idx=add_dataloader_idx ) def write_prediction( From eb51acf75512903f5f0419d2322154cc74584dd6 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 2 Mar 2021 10:26:36 +0100 Subject: [PATCH 6/8] change to boringmodel --- pytorch_lightning/core/lightning.py | 2 +- .../trainer/logging_/test_logger_connector.py | 21 +++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9786e7f8dc494..73543d7dbe40c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -295,7 +295,7 @@ def log( ) training_type_plugin = self.trainer.training_type_plugin - + # Determine if dataloader index should be added dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 65fecb16223ab..b14abf42b447f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -28,7 +28,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel, RandomDataset @@ -473,24 +472,27 @@ def training_step(self, *args, **kwargs): trainer.fit(model) -@pytest.mark.parametrize("auto_add_dataloader_idx", [False, True]) -def test_auto_add_dataloader_idx(tmpdir, auto_add_dataloader_idx): +@pytest.mark.parametrize("add_dataloader_idx", [False, True]) +def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx): """ test that auto_add_dataloader_idx argument works """ - class TestModel(EvalModelTemplate): + class TestModel(BoringModel): + def val_dataloader(self): + dl = super().val_dataloader() + return [dl, dl] def validation_step(self, 
*args, **kwargs): - output = super().validation_step(*args, **kwargs) - if auto_add_dataloader_idx: + output = super().validation_step(*args[:-1], **kwargs) + if add_dataloader_idx: name = "val_loss" else: name = f"val_loss_custom_naming_{args[-1]}" - self.log(name, output["val_loss"], auto_add_dataloader_idx=auto_add_dataloader_idx) + self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx) return output model = TestModel() - model.val_dataloader = model.val_dataloader__multiple + model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -499,7 +501,8 @@ def validation_step(self, *args, **kwargs): trainer.fit(model) logged = trainer.logged_metrics - if auto_add_dataloader_idx: + # Check that the correct keys exist + if add_dataloader_idx: assert 'val_loss/dataloader_idx_0' in logged assert 'val_loss/dataloader_idx_1' in logged else: From 8ffe97201d298aac1386f60ceb60cea2eaf3fa51 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 2 Mar 2021 10:29:59 +0100 Subject: [PATCH 7/8] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98332ee496fca..2888ca7cf0acb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) + + ### Changed From 79b0eae27564d2cec9e0ff2b4c44d351f395d80d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 2 Mar 2021 13:34:37 +0100 Subject: [PATCH 8/8] fix tests --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 00a7b57f6549d..743e5b5425c1e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -373,7 +373,7 @@ def log_dict( sync_dist_op=sync_dist_op, tbptt_pad_token=tbptt_pad_token, tbptt_reduce_fx=tbptt_reduce_fx, - auto_add_dataloader_idx=add_dataloader_idx + add_dataloader_idx=add_dataloader_idx ) def write_prediction(
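
For reference, below is a minimal usage sketch (not taken from the patches above) of the `add_dataloader_idx` flag that this series adds to `self.log` / `self.log_dict`. It assumes a hypothetical LightningModule with two validation dataloaders; the model, dataset, and metric names are placeholders invented for illustration, and the expected logged keys mirror the `test_auto_add_dataloader_idx` test in patch 6.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MultiValLoaderModel(pl.LightningModule):
    """Hypothetical model that validates on two dataloaders."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # default behaviour: the dataloader index is appended automatically,
        # producing keys val_loss/dataloader_idx_0 and val_loss/dataloader_idx_1
        self.log("val_loss", loss)
        # with add_dataloader_idx=False the user supplies a unique name per dataloader
        self.log(f"val_loss_loader_{dataloader_idx}", loss, add_dataloader_idx=False)

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=16)

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return [DataLoader(ds, batch_size=16), DataLoader(ds, batch_size=16)]

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(MultiValLoaderModel())
    # trainer.logged_metrics would then contain both the auto-suffixed keys and
    # the custom-named keys, matching the assertions in test_auto_add_dataloader_idx.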