Commit d26e408: GRUModule test
albertbou92 committed Nov 29, 2023
1 parent 45c6aa2 · commit d26e408
Showing 1 changed file with 1 addition and 45 deletions.
torchrl/modules/tensordict_module/rnn.py (46 changes: 1 addition, 45 deletions)
@@ -856,7 +856,7 @@ def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
 
         return hy
 
-    def _gru_old(self, x, hx):
+    def _gru(self, x, hx):
 
         if self.batch_first is False:
             x = x.permute(
@@ -904,50 +904,6 @@ def _gru_old(self, x, hx):
 
         return outputs, h_t
 
-    def _gru(self, x, hx):
-
-        if self.batch_first is False:
-            x = x.permute(1, 0, 2)  # Change (seq_len, batch, features) to (batch, seq_len, features)
-
-        bs, seq_len, input_size = x.size()
-        h_t = hx.clone()
-
-        outputs = []
-
-        for t in range(seq_len):
-            x_t = x[:, t, :]
-
-            for layer in range(self.num_layers):
-
-                # Retrieve weights
-                weights = self._all_weights[layer]
-                weight_ih = getattr(self, weights[0])
-                weight_hh = getattr(self, weights[1])
-                if self.bias is True:
-                    bias_ih = getattr(self, weights[2])
-                    bias_hh = getattr(self, weights[3])
-                else:
-                    bias_ih = None
-                    bias_hh = None
-
-                h_t = self._gru_cell(
-                    x_t, h_t, weight_ih, bias_ih, weight_hh, bias_hh
-                )
-
-                # Apply dropout if in training mode and not the last layer
-                if layer < self.num_layers - 1:
-                    x_t = F.dropout(h_t, p=self.dropout, training=self.training)
-                else:
-                    x_t = h_t
-
-            outputs.append(x_t)
-
-        outputs = torch.stack(outputs, dim=1)
-        if self.batch_first is False:
-            outputs = outputs.permute(1, 0, 2)  # Change back (batch, seq_len, features) to (seq_len, batch, features)
-
-        return outputs, h_t
-
     def forward(self, input, hx=None):  # noqa: F811
         self._update_flat_weights()
         if input.dim() != 3:
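For context, the removed `_gru` loop above calls `self._gru_cell(x_t, h_t, weight_ih, bias_ih, weight_hh, bias_hh)` once per timestep and layer, but the body of `_gru_cell` lies outside this diff (only its signature appears in the first hunk header). The sketch below is a minimal, standalone version of such a cell, assuming the standard GRU equations and the parameter layout of torch.nn.GRU (weights of shape (3 * hidden_size, in_features), gate order reset/update/new). The free function name `gru_cell` and the sizes in the usage check are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F


def gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
    # Project the input and the previous hidden state. Each projection yields
    # three chunks ordered (reset, update, new), matching torch.nn.GRU.
    gates_x = F.linear(x, weight_ih, bias_ih)
    gates_h = F.linear(hx, weight_hh, bias_hh)
    x_r, x_z, x_n = gates_x.chunk(3, dim=-1)
    h_r, h_z, h_n = gates_h.chunk(3, dim=-1)

    r = torch.sigmoid(x_r + h_r)   # reset gate
    z = torch.sigmoid(x_z + h_z)   # update gate
    n = torch.tanh(x_n + r * h_n)  # candidate hidden state
    return (1 - z) * n + z * hx    # new hidden state


# Usage check against torch.nn.GRUCell, with hypothetical sizes.
cell = torch.nn.GRUCell(input_size=4, hidden_size=8)
x = torch.randn(2, 4)
h = torch.randn(2, 8)
out = gru_cell(x, h, cell.weight_ih, cell.bias_ih, cell.weight_hh, cell.bias_hh)
assert torch.allclose(out, cell(x, h), atol=1e-5)

Stacking such a cell over seq_len timesteps and num_layers layers, with F.dropout applied between layers during training, reproduces the loop structure visible in the removed `_gru`, which the surviving method (renamed from `_gru_old`) presumably mirrors.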
