diff --git a/docs/source-pytorch/common/index.rst b/docs/source-pytorch/common/index.rst
index 738e971aec532..42f7adcc2ed24 100644
--- a/docs/source-pytorch/common/index.rst
+++ b/docs/source-pytorch/common/index.rst
@@ -202,6 +202,13 @@ How-to Guides
    :col_css: col-md-4
    :height: 180
 
+.. displayitem::
+   :header: Truncated Backpropagation Through Time
+   :description: Efficiently step through time when training recurrent models
+   :button_link: ../common/tbptt.html
+   :col_css: col-md-4
+   :height: 180
+
 .. raw:: html
 
diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
new file mode 100644
index 0000000000000..063ef8c33d319
--- /dev/null
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -0,0 +1,108 @@
+##############################################
+Truncated Backpropagation Through Time (TBPTT)
+##############################################
+
+Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
+a much longer sequence. This is made possible by splitting each training batch
+along the time dimension into chunks of size k and feeding them one after
+another to the ``training_step``. To keep the forward propagation behavior
+unchanged, the hidden state must be carried over from one split to the next.
+
+
+.. code-block:: python
+
+    import torch
+    import torch.optim as optim
+    import pytorch_lightning as pl
+    from pytorch_lightning import LightningModule
+
+
+    class LitModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+
+            # 1. Switch to manual optimization
+            self.automatic_optimization = False
+
+            self.truncated_bptt_steps = 10
+            # ParityModuleRNN is a stand-in for any module returning (loss, hiddens)
+            self.my_rnn = ParityModuleRNN()
+
+        # 2. Remove the `hiddens` argument
+        def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+
+            # 3. Split the batch along the time dimension (dim=1) into chunks of size k
+            split_batches = torch.split(batch, self.truncated_bptt_steps, dim=1)
+
+            batch_size = 10
+            hidden_dim = 20
+            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
+            for split_batch in split_batches:
+                # 4. Perform the optimization in a loop
+                loss, hiddens = self.my_rnn(split_batch, hiddens)
+                self.manual_backward(loss)
+                opt.step()
+                opt.zero_grad()
+
+                # 5. "Truncate": detach so gradients stop at the split boundary
+                hiddens = hiddens.detach()
+
+            # 6. Remove the return of `hiddens`
+            # Returning loss in manual optimization is not needed
+            return None
+
+        def configure_optimizers(self):
+            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+
+
+    if __name__ == "__main__":
+        model = LitModel()
+        trainer = pl.Trainer(max_epochs=5)
+        trainer.fit(model, train_dataloader)  # define your own dataloader, e.g. as sketched below
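+
+The example assumes each batch is a single tensor of shape
+``(batch_size, seq_len, features)``. As a minimal standalone sketch (the
+concrete shapes below are illustrative, not required by Lightning), the
+splitting in step 3 behaves like this:
+
+.. code-block:: python
+
+    import torch
+
+    # A toy batch: 10 sequences of length 50 with 20 features each
+    batch = torch.randn(10, 50, 20)
+
+    # Chunk along the time dimension (dim=1) into splits of size k=10;
+    # the last chunk is shorter when seq_len is not divisible by k
+    splits = torch.split(batch, 10, dim=1)
+
+    assert len(splits) == 5
+    assert all(s.shape == (10, 10, 20) for s in splits)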
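+
+The ``train_dataloader`` used in ``trainer.fit`` is left to the reader above.
+A minimal sketch that makes the example runnable end to end, assuming random
+data (``RandomSequenceDataset`` is a hypothetical name, not part of the
+Lightning API):
+
+.. code-block:: python
+
+    import torch
+    from torch.utils.data import DataLoader, Dataset
+
+
+    class RandomSequenceDataset(Dataset):
+        """Yields random sequences of shape (seq_len, features)."""
+
+        def __init__(self, num_samples=100, seq_len=50, features=20):
+            self.data = torch.randn(num_samples, seq_len, features)
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, idx):
+            return self.data[idx]
+
+
+    # batch_size must match the hidden-state batch size used in `training_step`
+    train_dataloader = DataLoader(RandomSequenceDataset(), batch_size=10, drop_last=True)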