From 0dcbb8cd98cebf488e536ea0bad02563c0b570cc Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 21:33:07 -0700 Subject: [PATCH 01/17] deprecate-tbptt-trainer --- pytorch_lightning/core/lightning.py | 21 ++++++++++++++-- .../connectors/training_trick_connector.py | 20 ++++++++++----- pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 25 ++++++++++++++----- tests/deprecated_api/test_remove_1-5.py | 5 ++++ 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c6151d96b52dd..2992fcf8d7f6c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -59,7 +59,7 @@ class LightningModule( Module, ): # Below is for property support of JIT in PyTorch 1.7 - # since none of them is important when using JIT, we are going to ignore them. + # since none of these are important when using JIT, we are going to ignore them. __jit_unused_properties__ = [ "datamodule", "example_input_array", @@ -72,6 +72,8 @@ class LightningModule( "local_rank", "logger", "model_size", + "automatic_optimization", + "truncated_bptt_steps", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -104,6 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._current_hook_fx_name: Optional[str] = None self._current_dataloader_idx: Optional[int] = None self._automatic_optimization: bool = True + self._truncated_bptt_steps: Optional[int] = None self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: @@ -191,6 +194,18 @@ def automatic_optimization(self) -> bool: def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization + @property + def truncated_bptt_steps(self) -> Optional[int]: + """ + truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much a longer sequence. + If this is > 0, the training step is passed ``hiddens``. + """ + return self._truncated_bptt_steps + + @truncated_bptt_steps.setter + def truncated_bptt_steps(self, truncated_bptt_steps: Optional[int]) -> None: + self._truncated_bptt_steps = truncated_bptt_steps + @property def logger(self): """ Reference to the logger object in the Trainer. """ @@ -524,7 +539,9 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: batch_idx (int): Integer displaying index of this batch optimizer_idx (int): When using multiple optimizers, this argument will also be present. hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0 or + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. + Return: Any of. diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index 4c5a036c74823..fcfc17b2db1f1 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict, List, Optional, Union + from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities.distributed import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -23,12 +26,12 @@ def __init__(self, trainer): def on_trainer_init( self, - gradient_clip_val, - gradient_clip_algorithm, - track_grad_norm, - accumulate_grad_batches, - truncated_bptt_steps, - terminate_on_nan, + gradient_clip_val: float, + gradient_clip_algorithm: str, + track_grad_norm: Union[int, float, str], + accumulate_grad_batches: Union[int, Dict[int, int], List[list]], + truncated_bptt_steps: Optional[int], + terminate_on_nan: bool, ): self.trainer.terminate_on_nan = terminate_on_nan @@ -48,6 +51,11 @@ def on_trainer_init( self.trainer.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) + if truncated_bptt_steps > 0: + rank_zero_deprecation( + "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5." + " Set truncated_bptt_steps directly on the LightningModule instead." + ) self.trainer.truncated_bptt_steps = truncated_bptt_steps def configure_accumulated_gradients(self, accumulate_grad_batches): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d4cfd548c087b..8d29cd24d97db 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -280,7 +280,8 @@ def __init__( track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer - sequence. + sequence. This argument has been moved to LightningModule. It is deprecated here in v1.3 and + will be removed in v1.5. val_check_interval: How often to check the validation set. Use float to check within a training epoch, use int to check every n steps (batches). 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f96c17a0686ce..129c3b4ee46fd 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,7 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -432,12 +432,13 @@ def _track_gradient_norm(self): grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm) return grad_norm_dict - def tbptt_split_batch(self, batch): + def _tbptt_split_batch(self, batch) -> List[Any]: splits = [batch] - if self.trainer.truncated_bptt_steps is not None: + truncated_bptt_enabled = self._truncated_bptt_enabled() + if truncated_bptt_enabled: model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("tbptt_split_batch"): - splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) + splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps()) return splits def run_training_epoch(self): @@ -612,7 +613,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # lightning module hook - splits = self.tbptt_split_batch(batch) + splits = self._tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): @@ -876,11 +877,23 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): ) # pass hiddens if using tbptt - if self.trainer.truncated_bptt_steps is not None: + if self._truncated_bptt_enabled(): args.append(hiddens) return args + def _truncated_bptt_enabled(self) -> bool: + """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ + lightning_module = self.trainer.lightning_module + return self.trainer.truncated_bptt_steps > 0 or lightning_module.truncated_bptt_steps > 0 + + def _truncated_bptt_steps(self) -> Optional[int]: + lightning_module = self.trainer.lightning_module + # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 + if lightning_module.truncated_bptt_steps > 0: + return lightning_module.truncated_bptt_steps + return self.trainer.truncated_bptt_steps + def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 47a76b8c6db80..e2ccdaf57125d 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -389,3 +389,8 @@ def test_v1_5_0_datamodule_setter(): model.datamodule = datamodule with pytest.deprecated_call(match="The `LightningModule.datamodule`"): _ = model.datamodule + + +def test_v1_5_0_trainer_tbptt_steps(tmpdir): + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + _ = Trainer(truncated_bptt_steps=1) From b428df37f43ae0606d90f72fc71a68ffd548cec3 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 21:37:00 -0700 Subject: [PATCH 02/17] Update CHANGELOG.md --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba281553e7648..d662dae3a43eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Added `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) + + - Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/)) @@ -196,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) + + - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292)) From 887028c8fdfc6de9e007cb69fe2273574e921303 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 21:41:34 -0700 Subject: [PATCH 03/17] Update lightning.py --- pytorch_lightning/core/lightning.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2992fcf8d7f6c..717ee91d9db5c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -538,9 +538,9 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): Integer displaying index of this batch optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0 or - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. + hiddens(:class:`~torch.Tensor`): Passed in if either + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0 Return: @@ -1461,7 +1461,9 @@ def tbptt_split_batch(self, batch, split_size): Note: Called in the training loop after :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` - if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0. + if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + or :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0 + Each returned batch split is passed separately to :meth:`training_step`. """ @@ -1562,7 +1564,7 @@ def get_progress_bar_dict(self): if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - if self.trainer.truncated_bptt_steps is not None: + if self.trainer.truncated_bptt_steps > 0 or self.truncated_bptt_steps > 0: tqdm_dict["split_idx"] = self.trainer.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: From 2152f2273700a7865f0ce3d88dc7a99bed4edbc5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 21:45:55 -0700 Subject: [PATCH 04/17] test --- CHANGELOG.md | 2 +- .../trainer/connectors/training_trick_connector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d662dae3a43eb..44d094a18d971 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added -- Added `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) +- Added `LightningModule.truncated_bptt_steps` property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) - Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/)) diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index fcfc17b2db1f1..f27288d2b13f4 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -51,7 +51,7 @@ def on_trainer_init( self.trainer.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) - if truncated_bptt_steps > 0: + if truncated_bptt_steps is not None and truncated_bptt_steps > 0: rank_zero_deprecation( "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5." " Set truncated_bptt_steps directly on the LightningModule instead." From 03b062eec6a358bd9272c02b77b1148a7e9b0c36 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 21:59:59 -0700 Subject: [PATCH 05/17] Update lightning.py --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 717ee91d9db5c..c8fb26ac47da0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1564,7 +1564,8 @@ def get_progress_bar_dict(self): if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - if self.trainer.truncated_bptt_steps > 0 or self.truncated_bptt_steps > 0: + if (self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0 + ) or (self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0): tqdm_dict["split_idx"] = self.trainer.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: From 3d0eb17f5c2bc8cbda545ba5ce763334eb107e6f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 22:02:55 -0700 Subject: [PATCH 06/17] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 129c3b4ee46fd..556fbd57137e9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -885,7 +885,8 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. 
""" lightning_module = self.trainer.lightning_module - return self.trainer.truncated_bptt_steps > 0 or lightning_module.truncated_bptt_steps > 0 + return (lightning_module.truncated_bptt_steps is not None and lightning_module.truncated_bptt_steps > 0 + ) or (self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0) def _truncated_bptt_steps(self) -> Optional[int]: lightning_module = self.trainer.lightning_module From fc44fc14f388bb6c3291601a39ecb3fda59c4566 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 22:03:24 -0700 Subject: [PATCH 07/17] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 556fbd57137e9..ab06344154472 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -891,7 +891,7 @@ def _truncated_bptt_enabled(self) -> bool: def _truncated_bptt_steps(self) -> Optional[int]: lightning_module = self.trainer.lightning_module # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 - if lightning_module.truncated_bptt_steps > 0: + if lightning_module.truncated_bptt_steps is not None and lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps From e424b2a34fc69b83be3b0826211e4dcacda5df85 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 22:40:53 -0700 Subject: [PATCH 08/17] Update lightning.py --- pytorch_lightning/core/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c8fb26ac47da0..28d28bdd587a2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1564,8 +1564,9 @@ def get_progress_bar_dict(self): if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - if (self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0 - ) or (self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0): + module_tbptt_enabled = self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0 + trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0 + if module_tbptt_enabled or trainer_tbptt_enabled: tqdm_dict["split_idx"] = self.trainer.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: From 1ba38ea5a84c101d12425072a15b43194206dd9a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 22:53:00 -0700 Subject: [PATCH 09/17] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ab06344154472..562e336e55b78 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -432,7 +432,7 @@ def _track_gradient_norm(self): grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm) return grad_norm_dict - def _tbptt_split_batch(self, batch) -> List[Any]: + def _tbptt_split_batch(self, batch: Any) -> List[Any]: splits = [batch] truncated_bptt_enabled = self._truncated_bptt_enabled() if truncated_bptt_enabled: @@ -884,9 +884,11 @@ def build_train_args(self, batch, 
batch_idx, opt_idx, hiddens): def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ - lightning_module = self.trainer.lightning_module - return (lightning_module.truncated_bptt_steps is not None and lightning_module.truncated_bptt_steps > 0 - ) or (self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0) + lm = self.trainer.lightning_module + trainer = self.trainer + module_tbptt_enabled = lm.truncated_bptt_steps is not None and module.truncated_bptt_steps > 0 + trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0 + return module_tbptt_enabled or trainer_tbptt_enabled def _truncated_bptt_steps(self) -> Optional[int]: lightning_module = self.trainer.lightning_module From 074818d7d116718aa5840e72678a572a30a442dd Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 23:17:59 -0700 Subject: [PATCH 10/17] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 562e336e55b78..c0f1d6fd6b834 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -884,9 +884,10 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ - lm = self.trainer.lightning_module + module = self.trainer.lightning_module + module_tbptt_enabled = module.truncated_bptt_steps is not None and module.truncated_bptt_steps > 0 + trainer = self.trainer - module_tbptt_enabled = lm.truncated_bptt_steps is not None and module.truncated_bptt_steps > 0 trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0 return module_tbptt_enabled or trainer_tbptt_enabled From 020adeb54ce513a365a89a127ef04fc42edfc951 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 17:43:02 -0700 Subject: [PATCH 11/17] update docs --- docs/source/advanced/sequences.rst | 18 +++++++++++++----- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/core/lightning.py | 12 +++++------- pytorch_lightning/trainer/trainer.py | 5 ++--- pytorch_lightning/trainer/training_loop.py | 8 +++++--- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/docs/source/advanced/sequences.rst b/docs/source/advanced/sequences.rst index 759a671cc42ef..5c6b34b16cb87 100644 --- a/docs/source/advanced/sequences.rst +++ b/docs/source/advanced/sequences.rst @@ -40,13 +40,21 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh Lightning can handle TBTT automatically via this flag. -.. testcode:: +.. testcode:: python + + from pytorch_lightning import LightningModule + + class MyModel(LightningModule): - # DEFAULT (single backwards pass per batch) - trainer = Trainer(truncated_bptt_steps=None) + def __init__(self): + super().__init__() + # Important: This property activates truncated backpropagation through time + # Setting this value to 2 splits the batch into sequences of size 2 + self.truncated_bptt_steps = 2 - # (split batch into sequences of size 2) - trainer = Trainer(truncated_bptt_steps=2) + def training_step(batch, batch_idx, hiddens): + # The training_step will be passed a `hiddens` argument for the split batch + ... .. 
note:: If you need to modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3a3a409a2d7da..14bb7ad2bdd5f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -196,7 +196,7 @@ def training_step( - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 """ args[0] = self.to_device(args[0]) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 28d28bdd587a2..875435f05b51d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -106,7 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._current_hook_fx_name: Optional[str] = None self._current_dataloader_idx: Optional[int] = None self._automatic_optimization: bool = True - self._truncated_bptt_steps: Optional[int] = None + self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: @@ -195,7 +195,7 @@ def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization @property - def truncated_bptt_steps(self) -> Optional[int]: + def truncated_bptt_steps(self) -> int: """ truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much a longer sequence. If this is > 0, the training step is passed ``hiddens``. @@ -203,7 +203,7 @@ def truncated_bptt_steps(self) -> Optional[int]: return self._truncated_bptt_steps @truncated_bptt_steps.setter - def truncated_bptt_steps(self, truncated_bptt_steps: Optional[int]) -> None: + def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None: self._truncated_bptt_steps = truncated_bptt_steps @property @@ -538,9 +538,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): Integer displaying index of this batch optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if either + hiddens(:class:`~torch.Tensor`): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0 Return: @@ -1462,7 +1461,6 @@ def tbptt_split_batch(self, batch, split_size): Called in the training loop after :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 - or :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0 Each returned batch split is passed separately to :meth:`training_step`. 
@@ -1564,7 +1562,7 @@ def get_progress_bar_dict(self): if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - module_tbptt_enabled = self.truncated_bptt_steps is not None and self.truncated_bptt_steps > 0 + module_tbptt_enabled = self.truncated_bptt_steps > 0 trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0 if module_tbptt_enabled or trainer_tbptt_enabled: tqdm_dict["split_idx"] = self.trainer.split_idx diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8d29cd24d97db..70c8a184da5e8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -279,9 +279,8 @@ def __init__( track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. - truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer - sequence. This argument has been moved to LightningModule. It is deprecated here in v1.3 and - will be removed in v1.5. + truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5. + Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead. val_check_interval: How often to check the validation set. Use float to check within a training epoch, use int to check every n steps (batches). diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c0f1d6fd6b834..fa964da83b301 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -885,16 +885,18 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ module = self.trainer.lightning_module - module_tbptt_enabled = module.truncated_bptt_steps is not None and module.truncated_bptt_steps > 0 + module_tbptt_enabled = module.truncated_bptt_steps > 0 + if module_tbptt_enabled: + return True trainer = self.trainer trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0 - return module_tbptt_enabled or trainer_tbptt_enabled + return trainer_tbptt_enabled def _truncated_bptt_steps(self) -> Optional[int]: lightning_module = self.trainer.lightning_module # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 - if lightning_module.truncated_bptt_steps is not None and lightning_module.truncated_bptt_steps > 0: + if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps From a6179aee726b3ca02bdf40b425f97d0133655e58 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 17:48:12 -0700 Subject: [PATCH 12/17] Update accelerator.py --- pytorch_lightning/accelerators/accelerator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 14bb7ad2bdd5f..d7b5c31c04e49 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -195,9 +195,8 @@ def training_step( The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. 
- - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 - + - hiddens(:class:`~torch.Tensor`): + Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 """ args[0] = self.to_device(args[0]) From 418cb7e9bc298466f4907e742f524a4e5177c593 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 17:55:06 -0700 Subject: [PATCH 13/17] Update accelerator.py --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d7b5c31c04e49..8237f7687fe62 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -195,8 +195,8 @@ def training_step( The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - - hiddens(:class:`~torch.Tensor`): - Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + - hiddens(:class:`~torch.Tensor`): Passed in if + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 """ args[0] = self.to_device(args[0]) From 597c6431614783f38ed4cfa9490dc239ecc91e43 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 18:32:48 -0700 Subject: [PATCH 14/17] more docs --- docs/source/advanced/sequences.rst | 15 +++--- docs/source/common/lightning_module.rst | 61 +++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/docs/source/advanced/sequences.rst b/docs/source/advanced/sequences.rst index 5c6b34b16cb87..f010372d96c1b 100644 --- a/docs/source/advanced/sequences.rst +++ b/docs/source/advanced/sequences.rst @@ -52,16 +52,19 @@ Lightning can handle TBTT automatically via this flag. # Setting this value to 2 splits the batch into sequences of size 2 self.truncated_bptt_steps = 2 - def training_step(batch, batch_idx, hiddens): - # The training_step will be passed a `hiddens` argument for the split batch - ... + # Truncated back-propagation through time + def training_step(self, batch, batch_idx, hiddens): + # the training step must be updated to accept a ``hiddens`` argument + # hiddens are the hiddens from the previous truncated backprop step + out, hiddens = self.lstm(data, hiddens) + return { + "loss": ..., + "hiddens": hiddens + } .. note:: If you need to modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. -.. note:: Using this feature requires updating your LightningModule's - :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg. - ---------- Iterable Datasets diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 64aed36e024e6..ec9bddf5b4f98 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1005,6 +1005,67 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin -------------- +truncated_bptt_steps +^^^^^^^^^^^^^^^^^^^^ + +Truncated back prop breaks performs backprop every k steps of +a much longer sequence. + +If this is enabled, your batches will automatically get truncated +and the trainer will apply Truncated Backprop to it. 
+ +(`Williams et al. "An efficient gradient-based algorithm for on-line training of +recurrent network trajectories." +`_) + +.. testcode:: python + + from pytorch_lightning import LightningModule + + class MyModel(LightningModule): + + def __init__(self): + super().__init__() + # Important: This property activates truncated backpropagation through time + # Setting this value to 2 splits the batch into sequences of size 2 + self.truncated_bptt_steps = 2 + + # Truncated back-propagation through time + def training_step(self, batch, batch_idx, hiddens): + # the training step must be updated to accept a ``hiddens`` argument + # hiddens are the hiddens from the previous truncated backprop step + out, hiddens = self.lstm(data, hiddens) + return { + "loss": ..., + "hiddens": hiddens + } + +.. note:: If you need to modify how the batch is split, + override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. + +Lightning takes care to split your batch along the time-dimension. + +.. code-block:: python + + # we use the second as the time dimension + # (batch, time, ...) + sub_batch = batch[0, 0:t, ...] + +.. code-block:: python + + +To modify how the batch is split, +override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: + +.. testcode:: + + class LitMNIST(LightningModule): + def tbptt_split_batch(self, batch, split_size): + # do your own splitting on the batch + return splits + +-------------- + Hooks ^^^^^ This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``. From d239f990cebedc73a86e9dc63f26c709fe2fdad8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 18:36:20 -0700 Subject: [PATCH 15/17] tweaks --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 13 +++---------- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8237f7687fe62..87ad29a5296be 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -196,7 +196,7 @@ def training_step( - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. """ args[0] = self.to_device(args[0]) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 875435f05b51d..3c591223f069e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -539,7 +539,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: batch_idx (int): Integer displaying index of this batch optimizer_idx (int): When using multiple optimizers, this argument will also be present. hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. 
Return: @@ -1460,7 +1460,7 @@ def tbptt_split_batch(self, batch, split_size): Note: Called in the training loop after :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` - if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0 + if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. Each returned batch split is passed separately to :meth:`training_step`. diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fa964da83b301..9b4ea160ac321 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -884,21 +884,14 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ - module = self.trainer.lightning_module - module_tbptt_enabled = module.truncated_bptt_steps > 0 - if module_tbptt_enabled: - return True + return self._truncated_bptt_steps() > 0 - trainer = self.trainer - trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0 - return trainer_tbptt_enabled - - def _truncated_bptt_steps(self) -> Optional[int]: + def _truncated_bptt_steps(self) -> int: lightning_module = self.trainer.lightning_module # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps - return self.trainer.truncated_bptt_steps + return self.trainer.truncated_bptt_steps or 0 def save_loggers_on_train_batch_end(self): # when loggers should save to disk From 148ab3306f7d412b57d3bb3264cf299a3d3142a7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 4 May 2021 08:49:19 +0200 Subject: [PATCH 16/17] chlog --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d094a18d971..14fc7a99e4232 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `LightningModule.truncated_bptt_steps` property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) - - - Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/)) @@ -154,6 +151,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed `LightningModule.truncated_bptt_steps` to be property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) + + - Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. 
Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069)) From d1669fde876840198af78897373ce3ddd4822f98 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 09:34:17 -0700 Subject: [PATCH 17/17] comments --- docs/source/common/lightning_module.rst | 10 +++------- pytorch_lightning/core/lightning.py | 2 -- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index ec9bddf5b4f98..3865400121fe2 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1018,6 +1018,8 @@ and the trainer will apply Truncated Backprop to it. recurrent network trajectories." `_) +`Tutorial `_ + .. testcode:: python from pytorch_lightning import LightningModule @@ -1040,9 +1042,6 @@ recurrent network trajectories." "hiddens": hiddens } -.. note:: If you need to modify how the batch is split, - override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. - Lightning takes care to split your batch along the time-dimension. .. code-block:: python @@ -1051,13 +1050,10 @@ Lightning takes care to split your batch along the time-dimension. # (batch, time, ...) sub_batch = batch[0, 0:t, ...] -.. code-block:: python - - To modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: -.. testcode:: +.. testcode:: python class LitMNIST(LightningModule): def tbptt_split_batch(self, batch, split_size): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3c591223f069e..f5bb9e04c89aa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -541,7 +541,6 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: hiddens(:class:`~torch.Tensor`): Passed in if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. - Return: Any of. @@ -1461,7 +1460,6 @@ def tbptt_split_batch(self, batch, split_size): Called in the training loop after :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. - Each returned batch split is passed separately to :meth:`training_step`. """
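
For readers following the series end to end, here is a minimal before/after sketch of the migration this deprecation asks for. It is an illustrative example, not code from the patches: the ``SequenceModel`` name, the LSTM sizes, the loss, and the dataloader are placeholder assumptions; only the ``truncated_bptt_steps`` property, the ``hiddens`` argument, and the ``{"loss": ..., "hiddens": ...}`` return contract come from the changes above.

.. code-block:: python

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule, Trainer


    class SequenceModel(LightningModule):
        """Illustrative module; input/hidden sizes are placeholders."""

        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=32, hidden_size=32, batch_first=True)
            # After this series, TBPTT is configured on the module:
            # each batch is split along the time dimension into chunks of 2 steps.
            self.truncated_bptt_steps = 2

        def training_step(self, batch, batch_idx, hiddens):
            # ``hiddens`` carries the hidden state between splits
            # (``None`` before the first split).
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(out, y)
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    # Deprecated in v1.3, removed in v1.5:
    #   trainer = Trainer(truncated_bptt_steps=2)
    # Preferred after this series:
    trainer = Trainer()
    # trainer.fit(SequenceModel(), train_dataloader)  # train_dataloader assumed to exist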