Skip to content

Commit

Permalink
Merge pull request #1062 from Abdurrahheem:tests_lstm_init_no_hidden_…
Browse files Browse the repository at this point in the history
…states

Added test data and model for LSTM without hidden states initialisation
  • Loading branch information
asmorkalov authored Apr 27, 2023
2 parents 890703d + 62ebe5a commit ffa2587
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 0 deletions.
Binary file added testdata/dnn/onnx/data/input_lstm_init_h0_c0_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_init_h0_c0_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_init_h0_c0_2.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_lstm_init_h0_c0.npy
Binary file not shown.
26 changes: 26 additions & 0 deletions testdata/dnn/onnx/generate_onnx_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,33 @@ def forward(self, x):
lstm = LSTM(features, hidden, batch, bidirectional=True)
save_data_and_model("lstm_bidirectional", input, lstm)

class LSTM_hidden_state_inputs(nn.Module):

def __init__(self, features, hidden, batch, num_layers=1, bidirectional=False):
super(LSTM_hidden_state_inputs, self).__init__()
self.lstm = nn.LSTM(features, hidden, num_layers, bidirectional=bidirectional)

def forward(self, x, h, c):
return self.lstm(x, (h, c))[0]

batch = 1
features = 16
hidden = 8
seq_len = 2
num_layers = 1
bidirectional = False

lstm = LSTM_hidden_state_inputs(
features,
hidden,
batch,
num_layers=num_layers,
bidirectional=bidirectional
)
input = torch.randn(seq_len, batch, features)
h0 = torch.randn(num_layers + int(bidirectional), batch, hidden)
c0 = torch.randn(num_layers + int(bidirectional), batch, hidden)
save_data_and_model_multy_inputs("lstm_init_h0_c0", lstm, input, h0, c0, export_params=True)

class HiddenLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, is_bidirectional=False):
Expand Down
Binary file added testdata/dnn/onnx/models/lstm_init_h0_c0.onnx
Binary file not shown.

0 comments on commit ffa2587

Please sign in to comment.