From 1f2a586053faf95979e8411f1cb894341ede4af0 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 29 Nov 2023 19:29:43 +0100 Subject: [PATCH] GRUModule test --- test/test_tensordictmodules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 15f67ae4d2c..dbe4fa3e8e0 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -1775,6 +1775,7 @@ def test_multi_consecutive(self, shape, python_based): lstm_module_ss(td_ss) td_ss = step_mdp(td_ss, keep_other=True) td_ss["observation"][:] = _t + 1 + import ipdb; ipdb.set_trace() # asssert fails torch.testing.assert_close( td_ss["intermediate"], td["intermediate"][..., -1, :] ) @@ -2045,6 +2046,7 @@ def test_multi_consecutive(self, shape, python_based): gru_module_ss(td_ss) td_ss = step_mdp(td_ss, keep_other=True) td_ss["observation"][:] = _t + 1 + import ipdb; ipdb.set_trace() # asssert fails torch.testing.assert_close( td_ss["intermediate"], td["intermediate"][..., -1, :] )