fix lstm
albertbou92 committed Dec 1, 2023
1 parent 2ca1daa commit 1b88fcc
Showing 1 changed file with 3 additions and 21 deletions.
torchrl/modules/tensordict_module/rnn.py (3 additions, 21 deletions)
@@ -210,7 +210,7 @@ def _lstm(self, x, hx):

         # should check self.batch_first
         bs, seq_len, input_size = x.size()
-        h_t, c_t = [h.clone() for h in hx]
+        h_t, c_t = [list(h.unbind(0)) for h in hx]

         outputs = []
         for t in range(seq_len):
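
The switch from h.clone() to list(h.unbind(0)) turns each stacked state tensor into a Python list of per-layer tensors, so the per-timestep loop can rebind h_t[layer] instead of performing an indexed, in-place write into a shared tensor (a write autograd can reject during the recurrent unroll). A minimal sketch of the idea, assuming the states follow the torch.nn.LSTM shape convention of (num_layers, batch, hidden_size); the sizes and names below are illustrative, not from the commit:

import torch

num_layers, batch, hidden = 2, 3, 4  # illustrative sizes
h0 = torch.zeros(num_layers, batch, hidden)

# unbind(0) yields one view per layer; list() makes the result mutable so the
# loop can rebind entries (h_t[layer] = new_state) rather than assign into the
# stacked tensor in place.
h_t = list(h0.unbind(0))
h_t[0] = torch.tanh(h_t[0] + 1.0)  # rebinds the list entry; h0 is untouched
assert h0.sum() == 0  # the original stacked state is unchanged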
@@ -235,7 +235,7 @@ def _lstm(self, x, hx):
                 )

                 # Apply dropout if in training mode
-                if layer < self.num_layers - 1:
+                if layer < self.num_layers - 1 and self.dropout:
                     x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training)
                 else:  # No dropout after the last layer
                     x_t = h_t[layer]
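
With the added truthiness check, F.dropout is skipped entirely when the module was configured with dropout=0.0, instead of being called with p=0.0 for every layer and timestep. A stand-alone illustration of the guard (p stands in for self.dropout; not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)
p = 0.0  # a dropout probability of 0.0 is the usual "disabled" setting

# With p=0.0 the call would be a numerical no-op, but it still dispatches;
# testing the float's truthiness (0.0 is falsy) avoids the call altogether.
y = F.dropout(x, p=p, training=True) if p else x
assert torch.equal(y, x)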
@@ -248,7 +248,7 @@ def _lstm(self, x, hx):
             1, 0, 2
         )  # Change back (batch, seq_len, features) to (seq_len, batch, features)

-        return outputs, (h_t, c_t)
+        return outputs, (torch.stack(h_t, 0), torch.stack(c_t, 0))

     def forward(self, input, hx=None):  # noqa: F811
         self._update_flat_weights()
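
Because the states are kept as Python lists during the unroll, torch.stack rebuilds the stacked tensors at return time, preserving the (num_layers, batch, hidden_size) layout of (h_n, c_n) that torch.nn.LSTM callers expect. A short round-trip check (illustrative shapes, not from the commit):

import torch

h0 = torch.randn(2, 3, 4)    # (num_layers, batch, hidden_size)
h_t = list(h0.unbind(0))     # per-layer list used inside the loop
h_n = torch.stack(h_t, 0)    # reassembled for the return value
assert torch.equal(h_n, h0)  # exact round trip when the states are unchanged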
@@ -857,24 +857,6 @@ def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):

         return hy

-    # @staticmethod
-    # def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
-    #     x = x.view(-1, x.size(1))
-    #
-    #     gates_ih = F.linear(x, weight_ih, bias_ih)
-    #     gates_hh = F.linear(hx, weight_hh, bias_hh)
-    #
-    #     r_gate_ih, z_gate_ih, n_gate_ih = gates_ih.chunk(3, 1)
-    #     r_gate_hh, z_gate_hh, n_gate_hh = gates_hh.chunk(3, 1)
-    #
-    #     r_gate = (r_gate_ih + r_gate_hh).sigmoid()
-    #     z_gate = (z_gate_ih + z_gate_hh).sigmoid()
-    #     n_gate = (n_gate_ih + r_gate * n_gate_hh).tanh()
-    #
-    #     hy = (1 - z_gate) * n_gate + z_gate * hx
-    #
-    #     return hy
-
     def _gru(self, x, hx):

         if not self.batch_first:
