Add doc for TBPTT (#20422)
* Add doc for TBPTT

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove url to prevent linting error

* attempt to fix linter

* add tbptt.rst file

* adjust doc:

* nit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make example easily copy and runnable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address comments

* fix doc test warning

* Update docs/source-pytorch/common/tbptt.rst

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Alan Chu <alanchu@Alans-Air.Home>
Co-authored-by: Luca Antiga <luca@lightning.ai>
6 people authored Dec 10, 2024
1 parent ca59e4e commit 1c4612e
Showing 2 changed files with 66 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/source-pytorch/common/index.rst
@@ -202,6 +202,13 @@ How-to Guides
   :col_css: col-md-4
   :height: 180

.. displayitem::
   :header: Truncated Back-Propagation Through Time
   :description: Efficiently step through time when training recurrent models
   :button_link: ../common/tbptt.html
   :col_css: col-md-4
   :height: 180

.. raw:: html

    </div>
59 changes: 59 additions & 0 deletions docs/source-pytorch/common/tbptt.rst
@@ -0,0 +1,59 @@
##############################################
Truncated Backpropagation Through Time (TBPTT)
##############################################

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by splitting training batches along the
time dimension into chunks of size k and passing each chunk to the ``training_step``.
To keep the forward-propagation behavior consistent, the hidden state must be carried
over between the time-dimension splits.
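
The example below relies on a ``split_batch`` helper to perform this chunking. It is not
provided by Lightning; a minimal sketch, assuming each batch is an ``(input, target)`` pair
of tensors shaped ``(batch, time, features)``, could look like this:

.. code-block:: python

    # Hypothetical helper (not part of Lightning): chunk an (x, y) batch into
    # windows of length k along the time dimension (dim=1).
    def split_batch(batch, k):
        x, y = batch
        return [(x[:, i : i + k], y[:, i : i + k]) for i in range(0, x.size(1), k)]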


.. code-block:: python

    import torch
    import torch.optim as optim
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
            self.my_rnn = ParityModuleRNN()  # your recurrent module returning (loss, hiddens)

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch in chunks along the time dimension
            split_batches = split_batch(batch, self.truncated_bptt_steps)

            batch_size = 10
            hidden_dim = 20
            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

            opt = self.optimizers()
            for split in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split, hiddens)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()

                # 5. "Truncate"
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.my_rnn.parameters(), lr=0.001)


    if __name__ == "__main__":
        model = LitModel()
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(model, train_dataloader)  # define your own train_dataloader
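
``ParityModuleRNN`` and ``train_dataloader`` are left as placeholders above. Purely as an
illustration (this is not the helper of the same name from Lightning's test suite), a toy
recurrent module and a random dataloader compatible with the example could look like the
following; define them before the ``__main__`` block when copying the snippet:

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset


    class ParityModuleRNN(nn.Module):
        """Toy stand-in: a single-layer RNN with an MSE objective over each chunk."""

        def __init__(self, input_dim=5, hidden_dim=20):
            super().__init__()
            self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
            self.head = nn.Linear(hidden_dim, input_dim)

        def forward(self, split, hiddens):
            x, y = split
            out, hiddens = self.rnn(x, hiddens)  # out: (batch, k, hidden_dim)
            loss = nn.functional.mse_loss(self.head(out), y)
            return loss, hiddens


    # 10 random sequences of length 100 -> one batch of 10 samples, split into 10 chunks.
    data = torch.randn(10, 100, 5)
    train_dataloader = DataLoader(TensorDataset(data, data), batch_size=10)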
