Deprecate truncated_bptt_steps flag on Trainer in favor of same setting on the LightningModule #7323

Merged 17 commits on May 5, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added `LightningModule.truncated_bptt_steps` property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))


- Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/))


@@ -196,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated


- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))


- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))


18 changes: 13 additions & 5 deletions docs/source/advanced/sequences.rst
@@ -40,13 +40,21 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh

Lightning can handle TBPTT automatically via this flag.

Removed (the old Trainer-flag example):

    # DEFAULT (single backwards pass per batch)
    trainer = Trainer(truncated_bptt_steps=None)

    # (split batch into sequences of size 2)
    trainer = Trainer(truncated_bptt_steps=2)

Added (the new LightningModule example):

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        def training_step(self, batch, batch_idx, hiddens):
            # The training_step will be passed a `hiddens` argument for the split batch
            ...
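
For orientation, a fuller hypothetical sketch with an LSTM follows. The layer sizes, the loss, and the class name are illustrative assumptions; the key pattern is that ``training_step`` receives the previous split's ``hiddens`` and returns a dict containing both ``loss`` and ``hiddens`` so the state carries over to the next split.

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule

    class SequenceModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.truncated_bptt_steps = 2
            self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
            self.head = nn.Linear(20, 1)

        def training_step(self, batch, batch_idx, hiddens):
            # `batch` is one split of `truncated_bptt_steps` time steps;
            # `hiddens` is None on the first split of each full batch.
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y)
            # returning "hiddens" hands the state to the training_step of the next split
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)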

.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
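
A minimal sketch of such an override, assuming the batch is an ``(x, y)`` tuple whose time dimension is dim 1 (the default implementation handles the general case; this is the hook you would adapt for custom batch structures):

    from pytorch_lightning import LightningModule

    class MySplitModel(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # Slice both tensors along the time dimension into chunks of `split_size`;
            # each chunk is then passed separately to `training_step`.
            x, y = batch
            splits = []
            for start in range(0, x.shape[1], split_size):
                end = start + split_size
                splits.append((x[:, start:end], y[:, start:end]))
            return splits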
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -196,8 +196,7 @@ def training_step(
- batch_idx (int): Integer displaying index of this batch
- optimizer_idx (int): When using multiple optimizers, this argument will also be present.
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0
"""
args[0] = self.to_device(args[0])

27 changes: 23 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -59,7 +59,7 @@ class LightningModule(
Module,
):
# Below is for property support of JIT in PyTorch 1.7
# since none of them is important when using JIT, we are going to ignore them.
# since none of these are important when using JIT, we are going to ignore them.
__jit_unused_properties__ = [
"datamodule",
"example_input_array",
@@ -72,6 +72,8 @@ class LightningModule(
"local_rank",
"logger",
"model_size",
"automatic_optimization",
"truncated_bptt_steps",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -104,6 +106,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._current_hook_fx_name: Optional[str] = None
self._current_dataloader_idx: Optional[int] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -191,6 +194,18 @@ def automatic_optimization(self) -> bool:
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization

@property
def truncated_bptt_steps(self) -> int:
"""
Truncated back-propagation through time performs backprop every k steps of a much longer sequence.
If this is > 0, the training step is passed ``hiddens``.
"""
return self._truncated_bptt_steps

@truncated_bptt_steps.setter
def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
self._truncated_bptt_steps = truncated_bptt_steps
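
A small, hypothetical usage note: because this is an ordinary read/write property, it can also be set after construction, for example from a command-line argument, before handing the model to the Trainer. ``MyLitModel`` and ``train_loader`` below are placeholders, not names from this PR.

    from pytorch_lightning import Trainer

    model = MyLitModel()
    model.truncated_bptt_steps = 4   # each batch is split into chunks of 4 time steps
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, train_dataloader=train_loader)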

@property
def logger(self):
""" Reference to the logger object in the Trainer. """
@@ -524,7 +539,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
batch_idx (int): Integer displaying index of this batch
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0


Return:
Any of.
@@ -1444,7 +1460,8 @@ def tbptt_split_batch(self, batch, split_size):
Note:
Called in the training loop after
:meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0

Each returned batch split is passed separately to :meth:`training_step`.

"""
@@ -1545,7 +1562,9 @@ def get_progress_bar_dict(self):
if avg_training_loss is not None:
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

if self.trainer.truncated_bptt_steps is not None:
module_tbptt_enabled = self.truncated_bptt_steps > 0
trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
if module_tbptt_enabled or trainer_tbptt_enabled:
tqdm_dict["split_idx"] = self.trainer.split_idx

if self.trainer.logger is not None and self.trainer.logger.version is not None:
@@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,12 +26,12 @@ def __init__(self, trainer):

def on_trainer_init(
self,
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan,
gradient_clip_val: float,
gradient_clip_algorithm: str,
track_grad_norm: Union[int, float, str],
accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
truncated_bptt_steps: Optional[int],
terminate_on_nan: bool,
):

self.trainer.terminate_on_nan = terminate_on_nan
@@ -48,6 +51,11 @@ def on_trainer_init(
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
rank_zero_deprecation(
"Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
" Set truncated_bptt_steps directly on the LightningModule instead."
)
self.trainer.truncated_bptt_steps = truncated_bptt_steps
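
For readers migrating, a short sketch of the before/after that this deprecation implies; the value 2 is illustrative.

    from pytorch_lightning import LightningModule, Trainer

    # Before (deprecated in v1.3, removed in v1.5): configure TBPTT on the Trainer.
    trainer = Trainer(truncated_bptt_steps=2)

    # After: configure TBPTT on the LightningModule instead.
    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.truncated_bptt_steps = 2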

def configure_accumulated_gradients(self, accumulate_grad_batches):
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -279,8 +279,8 @@ def __init__(

track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
sequence.
truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5.
Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.

val_check_interval: How often to check the validation set. Use float to check within a training epoch,
use int to check every n steps (batches).
31 changes: 25 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
@@ -14,7 +14,7 @@

from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
@@ -432,12 +432,13 @@ def _track_gradient_norm(self):
grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
return grad_norm_dict

def tbptt_split_batch(self, batch):
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
splits = [batch]
if self.trainer.truncated_bptt_steps is not None:
truncated_bptt_enabled = self._truncated_bptt_enabled()
if truncated_bptt_enabled:
model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("tbptt_split_batch"):
splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
return splits

def run_training_epoch(self):
@@ -612,7 +613,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

# lightning module hook
splits = self.tbptt_split_batch(batch)
splits = self._tbptt_split_batch(batch)

for split_idx, split_batch in enumerate(splits):

@@ -876,11 +877,29 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
)

# pass hiddens if using tbptt
if self.trainer.truncated_bptt_steps is not None:
if self._truncated_bptt_enabled():
[Review comment] shall this also be a property?
Suggested change: `if self._truncated_bptt_enabled():` → `if self.truncated_bptt_enabled:`

args.append(hiddens)

return args

def _truncated_bptt_enabled(self) -> bool:
[Review comment] Suggested change: make this a property: `@property` / `def truncated_bptt_enabled(self) -> bool:`

""" Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
module = self.trainer.lightning_module
module_tbptt_enabled = module.truncated_bptt_steps > 0
if module_tbptt_enabled:
return True

trainer = self.trainer
trainer_tbptt_enabled = trainer.truncated_bptt_steps is not None and trainer.truncated_bptt_steps > 0
return trainer_tbptt_enabled

def _truncated_bptt_steps(self) -> Optional[int]:
lightning_module = self.trainer.lightning_module
# Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
if lightning_module.truncated_bptt_steps > 0:
return lightning_module.truncated_bptt_steps
return self.trainer.truncated_bptt_steps
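
The precedence rule above can be restated as a tiny standalone function. This is an illustration, not part of the diff; the assertions mirror the documented behaviour (the module value wins whenever it is > 0, otherwise the deprecated Trainer flag applies).

    from typing import Optional

    def resolve_tbptt_steps(module_steps: int, trainer_steps: Optional[int]) -> Optional[int]:
        # LightningModule.truncated_bptt_steps takes precedence over the deprecated Trainer flag.
        if module_steps > 0:
            return module_steps
        return trainer_steps

    assert resolve_tbptt_steps(4, 2) == 4        # module setting wins
    assert resolve_tbptt_steps(0, 2) == 2        # falls back to the Trainer flag
    assert resolve_tbptt_steps(0, None) is None  # TBPTT disabled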

def save_loggers_on_train_batch_end(self):
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
5 changes: 5 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -389,3 +389,8 @@ def test_v1_5_0_datamodule_setter():
model.datamodule = datamodule
with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
_ = model.datamodule


def test_v1_5_0_trainer_tbptt_steps(tmpdir):
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
_ = Trainer(truncated_bptt_steps=1)
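

A hypothetical companion test (not part of this diff) for the new module-level property could look like this; it relies only on the default of 0 set in `LightningModule.__init__` and the plain setter shown above.

    def test_lightning_module_truncated_bptt_steps_property():
        from pytorch_lightning import LightningModule

        class _Module(LightningModule):
            pass

        model = _Module()
        assert model.truncated_bptt_steps == 0  # default from LightningModule.__init__
        model.truncated_bptt_steps = 2
        assert model.truncated_bptt_steps == 2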