Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 24, 2024
1 parent 4f47693 commit 5bf16b1
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.nn.parameter import UninitializedTensorMixin
from torch.nn.parameter import Parameter, UninitializedTensorMixin
from torch.utils._pytree import tree_map

try:
Expand Down Expand Up @@ -3196,21 +3196,22 @@ def count_bytes(tensor):
if isinstance(tensor, MemoryMappedTensor):
add(tensor)
return
if type(tensor) is not torch.Tensor:
try:
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
except AttributeError:
warnings.warn(
"The sub-tensor doesn't have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return
if type(tensor) in (Tensor, Parameter, Buffer):
pass
elif hasattr(tensor, "__tensor_flatten__"):
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
else:
warnings.warn(
"The sub-tensor doesn't have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return

grad = getattr(tensor, "grad", None)
if grad is not None:
Expand Down

0 comments on commit 5bf16b1

Please sign in to comment.