LSTM module and tests
albertbou92 committed Nov 29, 2023
1 parent 283919b commit 2a7e5fb
Showing 1 changed file with 43 additions and 0 deletions.
test/test_modules.py: 43 additions & 0 deletions
@@ -22,6 +22,7 @@
    MultiAgentMLP,
    OnlineDTActor,
    PythonGRUCell,
    PythonLSTM,
    PythonLSTMCell,
    QMixer,
    SafeModule,
@@ -1249,6 +1250,48 @@ def test_python_gru_cell(device, bias):
    assert hx1.shape == hx2.shape


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.5])
def test_python_lstm(device, bias, dropout):
    lstm1 = PythonLSTM(
        input_size=10,
        hidden_size=20,
        num_layers=2,
        dropout=dropout,
        device=device,
        bias=bias,
    )
    lstm2 = nn.LSTM(
        input_size=10,
        hidden_size=20,
        num_layers=2,
        dropout=dropout,
        device=device,
        bias=bias,
    )

    # Make sure parameters match
    for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()):
        assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
        assert (
            v1.shape == v2.shape
        ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

    # Inputs follow nn.LSTM's default (seq_len, batch, feature) layout, so the
    # hidden and cell states must have shape (num_layers, batch, hidden_size).
    input = torch.randn(5, 3, 10, device=device)
    h0 = torch.randn(2, 3, 20, device=device)
    c0 = torch.randn(2, 3, 20, device=device)

    # Test without hidden states
    with torch.no_grad():
        output1, (hn1, cn1) = lstm1(input)
        output2, (hn2, cn2) = lstm2(input)

    assert hn1.shape == hn2.shape
    assert cn1.shape == cn2.shape
    assert output1.shape == output2.shape

    # Test with hidden states
    with torch.no_grad():
        output1, (hn1, cn1) = lstm1(input, (h0, c0))
        output2, (hn2, cn2) = lstm2(input, (h0, c0))

    assert hn1.shape == hn2.shape
    assert cn1.shape == cn2.shape
    assert output1.shape == output2.shape


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
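
For reference, here is a minimal usage sketch of the interface this test exercises. It assumes, as the assertions above imply, that PythonLSTM mirrors nn.LSTM's constructor and call signature; the import path and the concrete shapes are illustrative, not taken from the diff.

import torch

from torchrl.modules import PythonLSTM  # import path assumed from the test's context

# Same configuration as the test: two stacked layers mapping 10 features to 20.
lstm = PythonLSTM(input_size=10, hidden_size=20, num_layers=2)

seq_len, batch = 5, 3
x = torch.randn(seq_len, batch, 10)  # (seq_len, batch, input_size), as in nn.LSTM

# Initial hidden and cell states: a pair of (num_layers, batch, hidden_size) tensors.
h0 = torch.zeros(2, batch, 20)
c0 = torch.zeros(2, batch, 20)

output, (hn, cn) = lstm(x, (h0, c0))
assert output.shape == (seq_len, batch, 20)  # top-layer hidden state at every step
assert hn.shape == cn.shape == (2, batch, 20)  # final (h, c) for each layer

As in the first half of the test, omitting the (h0, c0) tuple should default both states to zeros, matching nn.LSTM's behavior.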
