From a87642881688c25bc8c7bc2921eab22f59f87cef Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 29 Nov 2023 12:18:47 +0100 Subject: [PATCH] add cell tests --- test/test_modules.py | 63 +++++++++++++++++++ torchrl/modules/__init__.py | 2 + torchrl/modules/tensordict_module/__init__.py | 2 +- 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/test/test_modules.py b/test/test_modules.py index ee1884c5573..e246d71ee60 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -21,6 +21,8 @@ MultiAgentConvNet, MultiAgentMLP, OnlineDTActor, + PythonGRUCell, + PythonLSTMCell, QMixer, SafeModule, TanhModule, @@ -1186,6 +1188,67 @@ def test_onlinedtactor(self, batch_dims, T=5): assert (dtactor.log_std_max > sig.log()).all() +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_lstm_cell(device, bias): + + lstm_cell1 = PythonLSTMCell(10, 20, device=device, bias=bias) + lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + lstm_cell1.named_parameters(), lstm_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10).to(device) + hx1 = torch.randn(3, 20).to(device) + cx1 = torch.randn(3, 20).to(device) + hx2 = torch.randn(3, 20).to(device) + cx2 = torch.randn(3, 20).to(device) + with torch.no_grad(): + for i in range(input.size()[0]): + hx1, cx1 = lstm_cell1(input[i], (hx1, cx1)) + hx2, cx2 = lstm_cell2(input[i], (hx2, cx2)) + + # Make sure the final hidden states have the same shape + assert hx1.shape == hx2.shape + assert cx1.shape == cx2.shape + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_gru_cell(device, bias): + + gru_cell1 = PythonGRUCell(10, 20, device=device, bias=bias) + gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + gru_cell1.named_parameters(), gru_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10).to(device) + hx1 = torch.randn(3, 20).to(device) + hx2 = torch.randn(3, 20).to(device) + with torch.no_grad(): + for i in range(input.size()[0]): + hx1 = gru_cell1(input[i], hx1) + hx2 = gru_cell2(input[i], hx2) + + # Make sure the final hidden states have the same shape + assert hx1.shape == hx2.shape + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 16d621f2bec..6ceeedc5780 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -61,6 +61,8 @@ LSTMModule, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, + PythonGRUCell, + PythonLSTMCell, QValueActor, QValueHook, QValueModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 7605238f99a..e456684cf5c 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -31,6 +31,6 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import GRUModule, LSTMModule +from .rnn import GRUModule, LSTMModule, PythonGRUCell, PythonLSTMCell from .sequence import SafeSequential from .world_models import WorldModelWrapper