diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index a6ae28204a7..82ac3a081a0 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -220,7 +220,7 @@ def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh): # gates = F.linear(x, weight_ih, bias_ih) + F.linear(hx, weight_hh, bias_hh) if bias_ih is not None: - gates = x @ weight_ih.T + bias_ih + hx @ weight_ih.T + bias_hh + gates = x @ weight_ih.T + bias_ih + hx @ weight_hh.T + bias_hh else: gates = x @ weight_ih.T + hx @ weight_hh.T