Commit
* Add doc for TBPTT
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove url to prevent linting error
* attempt to fix linter
* add tbptt.rst file
* adjust doc
* nit
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* make example easily copy and runnable
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* address comments
* fix doc test warning
* Update docs/source-pytorch/common/tbptt.rst

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Alan Chu <alanchu@Alans-Air.Home>
Co-authored-by: Luca Antiga <luca@lightning.ai>
1 parent ca59e4e · commit 1c4612e
Showing 2 changed files with 66 additions and 0 deletions.
docs/source-pytorch/common/tbptt.rst
@@ -0,0 +1,59 @@
##############################################
Truncated Backpropagation Through Time (TBPTT)
##############################################
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by splitting training batches along the
time dimension into chunks of size k and passing each chunk to ``training_step``. To keep
the forward-propagation behavior unchanged, the hidden state must be carried over between
consecutive time-dimension splits.
.. code-block:: python

    import torch
    import torch.optim as optim
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch in chunks along the time dimension
            #    (split_batch is a user-defined helper; see the sketch below)
            split_batches = split_batch(batch, self.truncated_bptt_steps)

            batch_size = 10
            hidden_dim = 20
            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

            opt = self.optimizers()
            for split in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split, hiddens)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()

                # 5. "Truncate": detach the hidden state so gradients stop at the split boundary
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.my_rnn.parameters(), lr=0.001)


    if __name__ == "__main__":
        model = LitModel()
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(model, train_dataloader)  # Define your own dataloader
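
The example above relies on a ``split_batch`` helper, which is not part of the Lightning
API and must be defined by the user. A minimal sketch of such a helper, assuming the batch
is an ``(input, target)`` pair of tensors laid out as ``(batch, time, features)``, could
look like the following:

.. code-block:: python

    def split_batch(batch, split_size):
        # Hypothetical helper: split an (input, target) batch of shape
        # (batch, time, features) into chunks of ``split_size`` along dim 1.
        x, y = batch
        splits = []
        for t in range(0, x.size(1), split_size):
            splits.append((x[:, t : t + split_size], y[:, t : t + split_size]))
        return splits

Each chunk is then consumed by the loop in ``training_step``, with the detached hidden
state marking the truncation boundary between splits.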