
Commit

ananthsub committed May 4, 2021
1 parent 148ab33 commit d1669fd
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 3 additions & 7 deletions docs/source/common/lightning_module.rst
@@ -1018,6 +1018,8 @@ and the trainer will apply Truncated Backprop to it.
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)

`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_

.. testcode:: python

from pytorch_lightning import LightningModule
@@ -1040,9 +1042,6 @@ recurrent network trajectories."
"hiddens": hiddens
}
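
The diff shows only the tail of this testcode block. Purely as a hedged sketch of the documented pattern (the module name ``LitRNN``, the layer sizes, and the choice of ``truncated_bptt_steps = 10`` are illustrative assumptions, not part of the original docs), a full TBPTT ``training_step`` returning ``hiddens`` could look like:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule


    class LitRNN(LightningModule):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
            self.head = torch.nn.Linear(64, 10)
            # Illustrative: split each batch into chunks of 10 time steps.
            self.truncated_bptt_steps = 10

        def training_step(self, batch, batch_idx, hiddens):
            # `batch` is one time-slice of the full batch; `hiddens` is the
            # recurrent state carried over from the previous slice (None at first).
            x, y = batch
            out, hiddens = self.rnn(x, hiddens)
            loss = F.cross_entropy(self.head(out[:, -1]), y)
            # Returning "hiddens" lets the trainer pass the state to the next slice.
            return {"loss": loss, "hiddens": hiddens}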

.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

Lightning takes care to split your batch along the time-dimension.

.. code-block:: python
@@ -1051,13 +1050,10 @@ Lightning takes care to split your batch along the time-dimension.
# (batch, time, ...)
sub_batch = batch[0, 0:t, ...]
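
As a self-contained sketch of what splitting along the second (time) dimension means (the tensor shapes and ``split_size`` below are made up for illustration):

.. code-block:: python

    import torch

    # A (batch, time, features) tensor, e.g. 4 sequences of 100 steps each.
    batch = torch.randn(4, 100, 16)
    split_size = 20

    # Consecutive windows of `split_size` steps along dim 1, the time dimension.
    splits = [batch[:, t : t + split_size] for t in range(0, batch.shape[1], split_size)]

    print(len(splits), splits[0].shape)  # 5 torch.Size([4, 20, 16])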
.. code-block:: python
To modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode::
.. testcode:: python

class LitMNIST(LightningModule):
def tbptt_split_batch(self, batch, split_size):
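
The override body is truncated in the diff. A hedged sketch of what such an override could look like (assuming the batch is a list of tensors that all carry the time dimension at ``dim=1``):

.. code-block:: python

    from pytorch_lightning import LightningModule


    class LitMNIST(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # Assumption: every element of `batch` is a tensor shaped (batch, time, ...).
            splits = []
            time_dims = batch[0].shape[1]
            for t in range(0, time_dims, split_size):
                # Slice every tensor in the batch to the same window of time steps.
                split = [x[:, t : t + split_size] for x in batch]
                splits.append(split)
            # Each element of `splits` is fed to training_step as its own sub-batch.
            return splits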
2 changes: 0 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -541,7 +541,6 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
Return:
Any of.
@@ -1461,7 +1460,6 @@ def tbptt_split_batch(self, batch, split_size):
Called in the training loop after
:meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
Each returned batch split is passed separately to :meth:`training_step`.
"""

