Commit

add to doc
albertbou92 committed Dec 1, 2023
1 parent 3c25eba commit 3917fee
Showing 2 changed files with 21 additions and 113 deletions.
4 changes: 4 additions & 0 deletions docs/source/reference/modules.rst
@@ -333,7 +333,11 @@ algorithms, such as DQN, DDPG or Dreamer.
DistributionalDQNnet
DreamerActor
DuelingCnnDQNet
GRUCell
GRU
GRUModule
LSTMCell
LSTM
LSTMModule
ObsDecoder
ObsEncoder
130 changes: 17 additions & 113 deletions torchrl/modules/tensordict_module/rnn.py
@@ -28,63 +28,12 @@
class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.
.. math::
\begin{array}{ll}
i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}
where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
Args:
input_size: The number of expected features in the input `x`
hidden_size: The number of features in the hidden state `h`
bias: If ``False``, then the layer does not use bias weights `b_ih` and
`b_hh`. Default: ``True``
Inputs: input, (h_0, c_0)
- **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
- **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
- **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
Outputs: (h_1, c_1)
- **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
- **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
Attributes:
weight_ih: the learnable input-hidden weights, of shape
`(4*hidden_size, input_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(4*hidden_size, hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
.. note::
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
where :math:`k = \frac{1}{\text{hidden\_size}}`
On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
Examples::
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20) # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
... hx, cx = rnn(input[i], (hx, cx))
... output.append(hx)
>>> output = torch.stack(output, dim=0)
This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`.
"""

__doc__ += nn.LSTMCell.__doc__
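
Editor's sketch (not part of the commit): because the cell avoids CuDNN kernels, it can be mapped over an extra batch dimension with vmap, which is exactly what the note above claims. This assumes ``torchrl.modules`` exports ``LSTMCell`` with the same call signature as ``nn.LSTMCell``::

    >>> import torch
    >>> from torchrl.modules import LSTMCell  # assumed export; see the modules.rst additions above
    >>> cell = LSTMCell(10, 20)  # (input_size, hidden_size)
    >>> x = torch.randn(4, 3, 10)  # (extra_batch, batch, input_size)
    >>> hx = torch.randn(4, 3, 20)  # (extra_batch, batch, hidden_size)
    >>> cx = torch.randn(4, 3, 20)
    >>> # vmap over the leading dim; the cell's parameters are closed over, not batched
    >>> h1, c1 = torch.vmap(lambda x, h, c: cell(x, (h, c)))(x, hx, cx)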

def __init__(
self,
input_size: int,
@@ -149,7 +98,11 @@ def lstm_cell(self, x, hx, cx):


class LSTM(nn.LSTM):
"""A module that runs multiple steps of a multi-layer LSTM and is only coded in Python."""
"""A module that runs multiple steps of a multi-layer LSTM and is only coded in Python.
.. note::
This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`.
"""

__doc__ += nn.LSTM.__doc__
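
Editor's sketch (an assumption, not from the diff): since the class subclasses ``nn.LSTM`` and appends its docstring, it should behave as a drop-in replacement::

    >>> import torch
    >>> from torchrl.modules import LSTM  # assumed export, matching the rst additions
    >>> lstm = LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
    >>> x = torch.randn(3, 5, 10)  # (batch, time, input_size)
    >>> output, (h_n, c_n) = lstm(x)  # same return convention as nn.LSTM
    >>> output.shape
    torch.Size([3, 5, 20])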

@@ -692,65 +645,12 @@ def _lstm(
class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.
.. math::
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
Args:
input_size: The number of expected features in the input `x`
hidden_size: The number of features in the hidden state `h`
bias: If ``False``, then the layer does not use bias weights `b_ih` and
`b_hh`. Default: ``True``
Inputs: input, hidden
- **input** : tensor containing input features
- **hidden** : tensor containing the initial hidden
state for each element in the batch.
Defaults to zero if not provided.
Outputs: h'
- **h'** : tensor containing the next hidden state
for each element in the batch
Shape:
- input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
:math:`H_{in}` = `input_size`.
- hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
- output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
Attributes:
weight_ih: the learnable input-hidden weights, of shape
`(3*hidden_size, input_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(3*hidden_size, hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
.. note::
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
where :math:`k = \frac{1}{\text{hidden\_size}}`
On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
Examples::
>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
... hx = rnn(input[i], hx)
... output.append(hx)
This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`.
"""

__doc__ += nn.GRUCell.__doc__
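
Editor's sketch mirroring the LSTMCell note (an assumption, not in the commit): the Python-only cell is a plain callable, so vmap can be applied to it directly::

    >>> import torch
    >>> from torchrl.modules import GRUCell  # assumed export; see the modules.rst additions above
    >>> cell = GRUCell(10, 20)  # (input_size, hidden_size)
    >>> x = torch.randn(4, 3, 10)  # (extra_batch, batch, input_size)
    >>> hx = torch.randn(4, 3, 20)  # (extra_batch, batch, hidden_size)
    >>> h1 = torch.vmap(cell)(x, hx)  # no CuDNN kernels, so vmap traces through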

def __init__(
self,
input_size: int,
@@ -809,9 +709,13 @@ def gru_cell(self, x, hx):


class GRU(nn.GRU):
"""A module that runs multiple steps of a multi-layer GRU network and is only coded in Python."""
"""A module that runs multiple steps of a multi-layer GRU network and is only coded in Python.
__doc__ += nn.LSTM.__doc__
.. note::
This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`.
"""

__doc__ += nn.GRU.__doc__
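
Editor's note: the hunk above fixes the doc concatenation for ``GRU`` to use ``nn.GRU`` instead of ``nn.LSTM``. A quick illustrative check, assuming the class is exported from ``torchrl.modules`` and mirrors the ``nn.GRU`` constructor::

    >>> import torch
    >>> from torch import nn
    >>> from torchrl.modules import GRU
    >>> assert nn.GRU.__doc__ in GRU.__doc__  # appended in the class body above
    >>> gru = GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
    >>> x = torch.randn(3, 5, 10)  # (batch, time, input_size)
    >>> output, h_n = gru(x)  # same return convention as nn.GRU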

def __init__(
self,
