diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index ff4d5f9ee5c..0aedf667284 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -856,7 +856,7 @@ def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
 
         return hy
 
-    def _gru_old(self, x, hx):
+    def _gru(self, x, hx):
 
         if self.batch_first is False:
             x = x.permute(
@@ -904,50 +904,6 @@ def _gru_old(self, x, hx):
 
         return outputs, h_t
 
-    def _gru(self, x, hx):
-
-        if self.batch_first is False:
-            x = x.permute(1, 0, 2)  # Change (seq_len, batch, features) to (batch, seq_len, features)
-
-        bs, seq_len, input_size = x.size()
-        h_t = hx.clone()
-
-        outputs = []
-
-        for t in range(seq_len):
-            x_t = x[:, t, :]
-
-            for layer in range(self.num_layers):
-
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
-
-                h_t = self._gru_cell(
-                    x_t, h_t, weight_ih, bias_ih, weight_hh, bias_hh
-                )
-
-                # Apply dropout if in training mode and not the last layer
-                if layer < self.num_layers - 1:
-                    x_t = F.dropout(h_t, p=self.dropout, training=self.training)
-                else:
-                    x_t = h_t
-
-            outputs.append(x_t)
-
-        outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
-            outputs = outputs.permute(1, 0, 2)  # Change back (batch, seq_len, features) to (seq_len, batch, features)
-
-        return outputs, h_t
-
     def forward(self, input, hx=None):  # noqa: F811
         self._update_flat_weights()
         if input.dim() != 3:
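
Note: the second hunk removes a `_gru` that unrolled the recurrence in Python, stepping over time steps and layers and calling `self._gru_cell` once per step, and keeps the previous implementation (renamed from `_gru_old` back to `_gru`). For readers unfamiliar with that unroll pattern, below is a minimal standalone sketch of it using `torch.nn.GRUCell`. It is illustrative only, not the TorchRL code being removed: names such as `unrolled_gru` and `cells` are hypothetical, and it keeps one hidden state per layer.

    import torch
    import torch.nn as nn


    def unrolled_gru(x, h0, cells, batch_first=True):
        """Run a stack of GRU cells one time step at a time.

        x:     (batch, seq_len, input_size) if batch_first, else (seq_len, batch, input_size)
        h0:    (num_layers, batch, hidden_size) initial hidden states
        cells: list of nn.GRUCell, one per layer
        """
        if not batch_first:
            x = x.permute(1, 0, 2)  # to (batch, seq_len, features)
        h = list(h0.unbind(0))      # per-layer hidden states
        outputs = []
        for t in range(x.size(1)):
            inp = x[:, t, :]
            for layer, cell in enumerate(cells):
                h[layer] = cell(inp, h[layer])
                inp = h[layer]      # each layer feeds the next
            outputs.append(inp)     # top layer's output at step t
        out = torch.stack(outputs, dim=1)
        if not batch_first:
            out = out.permute(1, 0, 2)  # back to (seq_len, batch, features)
        return out, torch.stack(h, dim=0)


    # Example: 2-layer GRU, batch of 4, sequence length 10
    cells = [nn.GRUCell(16, 32), nn.GRUCell(32, 32)]
    x = torch.randn(4, 10, 16)             # (batch, seq, feature)
    h0 = torch.zeros(2, 4, 32)             # (num_layers, batch, hidden)
    out, h_n = unrolled_gru(x, h0, cells)  # out: (4, 10, 32), h_n: (2, 4, 32)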