Skip to content

Commit

Permalink
[Feature] TensorDict.record_stream
Browse files Browse the repository at this point in the history
ghstack-source-id: 1d9bcff8e4f6e308d8f8e9fa06b3da4eca8905f1
Pull Request resolved: #1016
  • Loading branch information
vmoens committed Oct 1, 2024
1 parent 4eff3e4 commit c8af3d5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TensorDictBase(MutableMapping):
_cache: bool = None
_is_non_tensor: bool = False
_memmap_prefix = None
_stream: torch.cuda.Stream | None = None

def __bool__(self) -> bool:
    """Refuse truth-value testing of a tensordict.

    Raises:
        RuntimeError: always; this method never returns a value.
    """
    message = "Converting a tensordict to boolean value is not permitted"
    raise RuntimeError(message)
Expand Down Expand Up @@ -7227,6 +7228,28 @@ def newfn(item_and_out):
out = torch.cat(imaplist, dim)
return out

# Stream
def record_stream(self, stream: torch.cuda.Stream):
    """Marks the tensordict as having been used by this stream.

    When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work
    queued on stream at the time of deallocation is complete.

    See :meth:`~torch.Tensor.record_stream` for more information.

    Args:
        stream (torch.cuda.Stream): the stream to record. It is stored on the
            instance and passed to ``record_stream`` of every item visited by
            ``_fast_apply``.

    Returns:
        The tensordict itself.

    Raises:
        RuntimeError: if a stream has already been recorded on this instance
            and it differs from ``stream``.
    """
    if self._stream is not None and self._stream != stream:
        # Must be raised, not returned: returning the exception object would
        # hand it back to the caller as a value and leave the conflict undetected.
        raise RuntimeError(
            "A stream is already associated with this TensorDict instance."
        )
    self._stream = stream

    def record(tensor):
        tensor.record_stream(stream)

    self._fast_apply(record, filter_empty=True)
    return self

# point-wise arithmetic ops
def __add__(self, other: TensorDictBase | torch.Tensor) -> T:
    """Implement ``self + other`` by forwarding ``other`` to ``self.add``."""
    result = self.add(other)
    return result
Expand Down
18 changes: 18 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,6 +2404,24 @@ def test_squeeze(self, device):
td1b = torch.squeeze(td2, dim=1)
assert td1b.batch_size == td1.batch_size

# Skipped when CUDA is unavailable: the test creates CUDA streams and CUDA tensors.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_record_stream(self):
# Checks that, once a stream has been recorded on a tensordict, recording a
# different stream raises a RuntimeError with the expected message.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the statements below relative to the `with` blocks should be confirmed
# against the repository.
# Two distinct streams created on device index 0.
s0 = torch.cuda.Stream(0)
s1 = torch.cuda.Stream(0)
with torch.cuda.stream(s1):
# One top-level entry and one entry under a nested (tuple) key.
td = TensorDict(
{
"a": torch.randn(3, device="cuda:0"),
("b", "c"): torch.randn(3, device="cuda:0"),
}
)
# First recording: associates s1 with the tensordict.
td.record_stream(s1)
with pytest.raises(
RuntimeError,
match="A stream is already associated with this TensorDict instance",
):
# Second recording with a different stream is expected to be rejected.
td.record_stream(s0)

@pytest.mark.parametrize("device", get_available_devices())
def test_subtensordict_construction(self, device):
torch.manual_seed(1)
Expand Down

0 comments on commit c8af3d5

Please sign in to comment.