Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 28, 2023
1 parent 33de0fc commit 492a884
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
5 changes: 3 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ class TestParallel:
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"])
@pytest.mark.parametrize("edevice", ["cpu", "cuda"])
def test_parallel_devices(self, parallel, hetero, pdevice, edevice):
@pytest.mark.parametrize("bwad", [True, False])
def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
if parallel:
cls = ParallelEnv
else:
Expand All @@ -375,7 +376,7 @@ def test_parallel_devices(self, parallel, hetero, pdevice, edevice):
env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice))
env = cls(2, [env1, env2], device=pdevice)

r = env.rollout(2)
r = env.rollout(2, break_when_any_done=bwad)
if pdevice is not None:
assert env.device == torch.device(pdevice)
assert r.device == torch.device(pdevice)
Expand Down
12 changes: 7 additions & 5 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,13 +450,11 @@ def _create_td(self) -> None:
# Multi-task: we share tensordict that *may* have different keys
# LazyStacked already stores this so we don't need to do anything
self.shared_tensordicts = self.shared_tensordict_parent
if self.device.type == "cpu":
if self.shared_tensordict_parent.device.type == "cpu":
if self._share_memory:
for td in self.shared_tensordicts:
td.share_memory_()
self.shared_tensordict_parent.share_memory_()
elif self._memmap:
for td in self.shared_tensordicts:
td.memmap_()
self.shared_tensordict_parent.memmap_()
else:
if self._share_memory:
self.shared_tensordict_parent.share_memory_()
Expand Down Expand Up @@ -946,6 +944,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
out = out.clone()
else:
out = out.to(device, non_blocking=True)
assert all(
val.device == device for val in
out.values(True, True)
)
return out

@_check_start
Expand Down

0 comments on commit 492a884

Please sign in to comment.