From 3917feef92b2bcf1a390e017875bb8816d7e5b57 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 1 Dec 2023 11:58:36 +0100
Subject: [PATCH] add to doc

---
 docs/source/reference/modules.rst        |   4 +
 torchrl/modules/tensordict_module/rnn.py | 130 +++--------------------
 2 files changed, 21 insertions(+), 113 deletions(-)

diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index 8d1e258502e..d859140bb70 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -333,7 +333,11 @@ algorithms, such as DQN, DDPG or Dreamer.
     DistributionalDQNnet
     DreamerActor
     DuelingCnnDQNet
+    GRUCell
+    GRU
     GRUModule
+    LSTMCell
+    LSTM
     LSTMModule
     ObsDecoder
     ObsEncoder
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index 9d067051dee..e4c1038cea0 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -28,63 +28,12 @@
 class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.

-    .. math::
-
-        \begin{array}{ll}
-        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
-        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
-        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
-        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
-        c' = f * c + i * g \\
-        h' = o * \tanh(c') \\
-        \end{array}
-
-    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
-
-    Args:
-        input_size: The number of expected features in the input `x`
-        hidden_size: The number of features in the hidden state `h`
-        bias: If ``False``, then the layer does not use bias weights `b_ih` and
-            `b_hh`. Default: ``True``
-
-    Inputs: input, (h_0, c_0)
-        - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
-        - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
-        - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
-
-          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
-
-    Outputs: (h_1, c_1)
-        - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
-        - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
-
-    Attributes:
-        weight_ih: the learnable input-hidden weights, of shape
-            `(4*hidden_size, input_size)`
-        weight_hh: the learnable hidden-hidden weights, of shape
-            `(4*hidden_size, hidden_size)`
-        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
-        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
-
     .. note::
-        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
-        where :math:`k = \frac{1}{\text{hidden\_size}}`
-
-        On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward.
-
-    Examples::
-
-        >>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
-        >>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
-        >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
-        >>> cx = torch.randn(3, 20)
-        >>> output = []
-        >>> for i in range(input.size()[0]):
-        ...     hx, cx = rnn(input[i], (hx, cx))
-        ...     output.append(hx)
-        >>> output = torch.stack(output, dim=0)
+        This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`.
""" + __doc__ += nn.LSTMCell.__doc__ + def __init__( self, input_size: int, @@ -149,7 +98,11 @@ def lstm_cell(self, x, hx, cx): class LSTM(nn.LSTM): - """A module that runs multiple steps of a multi-layer LSTM and is only coded in Python.""" + """A module that runs multiple steps of a multi-layer LSTM and is only coded in Python. + + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`. + """ __doc__ += nn.LSTM.__doc__ @@ -692,65 +645,12 @@ def _lstm( class GRUCell(RNNCellBase): r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.LSTMCell but is fully coded in Python. - .. math:: - - \begin{array}{ll} - r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ - z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ - n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ - h' = (1 - z) * n + z * h - \end{array} - - where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. - - Args: - input_size: The number of expected features in the input `x` - hidden_size: The number of features in the hidden state `h` - bias: If ``False``, then the layer does not use bias weights `b_ih` and - `b_hh`. Default: ``True`` - - Inputs: input, hidden - - **input** : tensor containing input features - - **hidden** : tensor containing the initial hidden - state for each element in the batch. - Defaults to zero if not provided. - - Outputs: h' - - **h'** : tensor containing the next hidden state - for each element in the batch - - Shape: - - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where - :math:`H_{in}` = `input_size`. - - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden - state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. - - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. - - Attributes: - weight_ih: the learnable input-hidden weights, of shape - `(3*hidden_size, input_size)` - weight_hh: the learnable hidden-hidden weights, of shape - `(3*hidden_size, hidden_size)` - bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` - bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` - .. note:: - All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` - where :math:`k = \frac{1}{\text{hidden\_size}}` - - On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. - - Examples:: - - >>> rnn = nn.GRUCell(10, 20) - >>> input = torch.randn(6, 3, 10) - >>> hx = torch.randn(3, 20) - >>> output = [] - >>> for i in range(6): - ... hx = rnn(input[i], hx) - ... output.append(hx) + This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`. """ + __doc__ += nn.GRUCell.__doc__ + def __init__( self, input_size: int, @@ -809,9 +709,13 @@ def gru_cell(self, x, hx): class GRU(nn.GRU): - """A module that runs multiple steps of a multi-layer GRU network and is only coded in Python.""" + """A module that runs multiple steps of a multi-layer GRU network and is only coded in Python. - __doc__ += nn.LSTM.__doc__ + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`functorch.vmap`. + """ + + __doc__ += nn.GRU.__doc__ def __init__( self,