From 941c074d5c6ca60d49672dcf7bd8f1f8e7aed5f1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 2 Dec 2023 13:07:08 +0000 Subject: [PATCH] avoid GRU/LSTM inheritance --- torchrl/modules/tensordict_module/rnn.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 57e2173a7f0..126922c4622 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -125,6 +125,17 @@ def lstm_cell(self, x, hx, cx): return hy, cy +# copy LSTM +class LSTMBase(nn.RNNBase): + def __init__(self, *args, **kwargs): + return super().__init__("LSTM", *args, **kwargs) + + +for attr in nn.LSTM.__dict__: + if attr != "__init__": + setattr(LSTMBase, attr, getattr(nn.LSTM, attr)) + + class LSTM(nn.LSTM): """A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like :class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python. @@ -799,7 +810,18 @@ def gru_cell(self, x, hx): return hy -class GRU(nn.GRU): +# copy GRU +class GRUBase(nn.RNNBase): + def __init__(self, *args, **kwargs): + return super().__init__("GRU", *args, **kwargs) + + +for attr in nn.GRU.__dict__: + if attr != "__init__": + setattr(GRUBase, attr, getattr(nn.GRU, attr)) + + +class GRU(GRUBase): """A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.