From 492a884a3a0afcc81354a926cda5289287204d9d Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 28 Nov 2023 11:05:46 +0000 Subject: [PATCH] amend --- test/test_env.py | 5 +++-- torchrl/envs/batched_envs.py | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 8bf51263147..917d7003534 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -361,7 +361,8 @@ class TestParallel: @pytest.mark.parametrize("hetero", [True, False]) @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) - def test_parallel_devices(self, parallel, hetero, pdevice, edevice): + @pytest.mark.parametrize("bwad", [True, False]) + def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): if parallel: cls = ParallelEnv else: @@ -375,7 +376,7 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice): env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice)) env = cls(2, [env1, env2], device=pdevice) - r = env.rollout(2) + r = env.rollout(2, break_when_any_done=bwad) if pdevice is not None: assert env.device == torch.device(pdevice) assert r.device == torch.device(pdevice) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b6f1d9d229f..2167e9db81c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -450,13 +450,11 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.device.type == "cpu": + if self.shared_tensordict_parent.device.type == "cpu": if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() + self.shared_tensordict_parent.share_memory_() elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -946,6 +944,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: out = out.clone() else: out = out.to(device, non_blocking=True) + assert all( + val.device == device for val in + out.values(True, True) + ) return out @_check_start