diff --git a/test/test_modules.py b/test/test_modules.py index e246d71ee60..a8cdfca2cf5 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -22,6 +22,7 @@ MultiAgentMLP, OnlineDTActor, PythonGRUCell, + PythonLSTM, PythonLSTMCell, QMixer, SafeModule, @@ -1249,6 +1250,48 @@ def test_python_gru_cell(device, bias): assert hx1.shape == hx2.shape +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("dropout", [0.0, 0.5]) +def test_python_lstm(device, bias, dropout): + + lstm1 = PythonLSTM( + input_size=10, hidden_size=20, num_layers=2, device=device, bias=bias + ) + lstm2 = nn.LSTM( + input_size=10, hidden_size=20, num_layers=2, device=device, bias=bias + ) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + input = torch.randn(5, 3, 10) + h0 = torch.randn(2, 5, 20) + c0 = torch.randn(2, 5, 20) + + # Test without hidden states + with torch.no_grad(): + output1, (hn1, cn1) = lstm1(input) + output2, (hn2, cn2) = lstm2(input) + + assert hn1.shape == hn2.shape + assert cn1.shape == cn2.shape + assert output1.shape == output2.shape + + # Test with hidden states + with torch.no_grad(): + output1, (hn1, cn1) = lstm1(input, (h0, c0)) + output2, (hn2, cn2) = lstm1(input, (h0, c0)) + + assert hn1.shape == hn2.shape + assert cn1.shape == cn2.shape + assert output1.shape == output2.shape + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)