Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Dec 4, 2023
1 parent 7e58cd8 commit e1d7311
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def lstm_cell(self, x, hx, cx):

# copy LSTM
class LSTMBase(nn.RNNBase):
"""A Base module for LSTM. Inheriting from LSTMBase enables compatibility with torch.compile."""

def __init__(self, *args, **kwargs):
return super().__init__("LSTM", *args, **kwargs)

Expand All @@ -137,8 +139,7 @@ def __init__(self, *args, **kwargs):


class LSTM(LSTMBase):
"""A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like
:class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python.
"""A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like :class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python.
.. note::
This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.
Expand Down Expand Up @@ -811,6 +812,8 @@ def gru_cell(self, x, hx):

# copy GRU
class GRUBase(nn.RNNBase):
"""A Base module for GRU. Inheriting from GRUBase enables compatibility with torch.compile."""

def __init__(self, *args, **kwargs):
return super().__init__("GRU", *args, **kwargs)

Expand All @@ -821,8 +824,7 @@ def __init__(self, *args, **kwargs):


class GRU(GRUBase):
"""A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like
:class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.
"""A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.
.. note::
This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.
Expand Down

0 comments on commit e1d7311

Please sign in to comment.