
Commit

amend
vmoens committed Dec 6, 2023
1 parent b2d37e6 commit e1c38f8
Showing 1 changed file with 9 additions and 23 deletions.
torchrl/modules/tensordict_module/rnn.py
@@ -218,11 +218,7 @@
     @staticmethod
     def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh):

-        # gates = F.linear(x, weight_ih, bias_ih) + F.linear(hx, weight_hh, bias_hh)
-        if bias_ih is not None:
-            gates = x @ weight_ih.T + bias_ih + hx @ weight_hh.T + bias_hh
-        else:
-            gates = x @ weight_ih.T + hx @ weight_hh.T
+        gates = F.linear(x, weight_ih, bias_ih) + F.linear(hx, weight_hh, bias_hh)

         i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1)
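The one-liner works because torch.nn.functional.linear computes x @ W.T + bias and accepts bias=None, so the deleted if/else branch was redundant. A minimal standalone check of that equivalence (tensor names and shapes here are illustrative, not from the commit):

    import torch
    from torch.nn import functional as F

    x = torch.randn(4, 8)     # (batch, input_size)
    W = torch.randn(32, 8)    # stacked LSTM gate weights: (4 * hidden_size, input_size)
    b = torch.randn(32)

    # F.linear covers both branches of the removed code path.
    assert torch.allclose(F.linear(x, W, b), x @ W.T + b)
    assert torch.allclose(F.linear(x, W, None), x @ W.T)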

@@ -246,20 +242,16 @@ def _lstm(self, x, hx):

         # should check self.batch_first
         bs, seq_len, input_size = x.size()
-        h_t_out = []
-        c_t_out = []
-
-        x_ts = x.unbind(1)
-        h_t, c_t = hx
-        h_t, c_t = h_t.unbind(0), c_t.unbind(0)
+        h_t, c_t = [h.unbind(0) for h in hx]

         outputs = []
-        for t in range(seq_len):
-
-            x_t = x_ts[t]
+        for x_t in x.unbind(1):
+            h_t_out = []
+            c_t_out = []

-            for layer in range(self.num_layers):
+            for layer, weights in enumerate(self._all_weights):
                 # Retrieve weights
-                weights = self._all_weights[layer]
                 weight_ih = getattr(self, weights[0])
                 weight_hh = getattr(self, weights[1])
                 if self.bias is True:
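Beyond dropping the temporary x_ts and the index-based loops, the rewrite appears to move the h_t_out / c_t_out resets inside the time loop, so the per-layer buffers are rebuilt at every step. A small sketch (shapes assumed for illustration) of the new iteration pattern:

    import torch

    x = torch.randn(2, 5, 3)  # (batch, seq_len, input_size), batch_first layout
    # unbind(1) yields one (batch, input_size) tensor per timestep, in order,
    # so `for x_t in x.unbind(1)` matches the old `x_ts[t]` indexing.
    for t, x_t in enumerate(x.unbind(1)):
        assert torch.equal(x_t, x[:, t])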
@@ -905,14 +897,8 @@
     def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
         x = x.view(-1, x.size(1))

-        # gate_x = F.linear(x, weight_ih, bias_ih)
-        # gate_h = F.linear(hx, weight_hh, bias_hh)
-        if bias_ih is not None:
-            gate_x = x @ weight_ih.T + bias_ih
-            gate_h = hx @ weight_hh.T + bias_hh
-        else:
-            gate_x = x @ weight_ih.T
-            gate_h = hx @ weight_hh.T
+        gate_x = F.linear(x, weight_ih, bias_ih)
+        gate_h = F.linear(hx, weight_hh, bias_hh)

         i_r, i_i, i_n = gate_x.chunk(3, 1)
         h_r, h_i, h_n = gate_h.chunk(3, 1)
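Same simplification as in the LSTM cell: F.linear tolerates a None bias, so both branches collapse. For a GRU, weight_ih has shape (3 * hidden_size, input_size), which is what makes the chunk(3, 1) split above produce the reset/input/new components. An illustrative shape check (all names assumed, not from the commit):

    import torch
    from torch.nn import functional as F

    B, In, H = 2, 4, 6                       # batch, input_size, hidden_size
    x, hx = torch.randn(B, In), torch.randn(B, H)
    weight_ih = torch.randn(3 * H, In)
    weight_hh = torch.randn(3 * H, H)

    gate_x = F.linear(x, weight_ih, None)    # bias-free path; bias_ih may be None
    gate_h = F.linear(hx, weight_hh, None)
    i_r, i_i, i_n = gate_x.chunk(3, 1)       # three (B, H) chunks
    h_r, h_i, h_n = gate_h.chunk(3, 1)
    assert i_r.shape == h_r.shape == (B, H)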
