Skip to content

Commit

Permalink
LSTMModule test
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Nov 29, 2023
1 parent 33608a8 commit 01921cf
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,8 @@ def test_multi_consecutive(self, shape, python_based):
td_ss["intermediate"], td["intermediate"][..., -1, :]
)

def test_lstm_parallel_env(self):
@pytest.mark.parametrize("python_based", [True, False])
def test_lstm_parallel_env(self, python_based):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

device = "cuda" if torch.cuda.device_count() else "cpu"
Expand All @@ -1791,6 +1792,7 @@ def test_lstm_parallel_env(self):
in_key="observation",
out_key="features",
device=device,
python_based=python_based,
)

def create_transformed_env():
Expand Down

0 comments on commit 01921cf

Please sign in to comment.