Skip to content

Commit

Permalink
avoid GRU/LSTM inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Dec 2, 2023
1 parent 71e455a commit 941c074
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def lstm_cell(self, x, hx, cx):
return hy, cy


# copy LSTM
class LSTMBase(nn.RNNBase):
def __init__(self, *args, **kwargs):
return super().__init__("LSTM", *args, **kwargs)


for attr in nn.LSTM.__dict__:
if attr != "__init__":
setattr(LSTMBase, attr, getattr(nn.LSTM, attr))


class LSTM(nn.LSTM):
"""A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like
:class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python.
Expand Down Expand Up @@ -799,7 +810,18 @@ def gru_cell(self, x, hx):
return hy


class GRU(nn.GRU):
# copy GRU
class GRUBase(nn.RNNBase):
def __init__(self, *args, **kwargs):
return super().__init__("GRU", *args, **kwargs)


for attr in nn.GRU.__dict__:
if attr != "__init__":
setattr(GRUBase, attr, getattr(nn.GRU, attr))


class GRU(GRUBase):
"""A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like
:class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.
Expand Down

0 comments on commit 941c074

Please sign in to comment.