Skip to content

Commit

Permalink
[BugFix] Proper auto-batch size for unbatched tensors
Browse files Browse the repository at this point in the history
ghstack-source-id: 1ad6616dfcdd55bd055512e96a1e942b27d02ec8
Pull Request resolved: #1213
  • Loading branch information
vmoens committed Feb 7, 2025
1 parent 0a8638d commit ba53d07
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_pass_through,
_prune_selected_keys,
_set_item,
_set_max_batch_size,
Expand Down Expand Up @@ -3809,7 +3810,7 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
def _get_str(self, key, default):
if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
data = self._source._get_str(key, NO_DEFAULT)
if is_non_tensor(data):
if _pass_through(data):
return data[self.idx]
return _SubTensorDict(data, self.idx)
return self._source._get_at_str(key, self.idx, default=default)
Expand Down
2 changes: 1 addition & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ def _set_max_batch_size(source: T, batch_dims=None):
"""Updates a tensordict with its maximum batch size."""
from tensordict.base import _is_tensor_collection

tensor_data = [val for val in source.values() if not is_non_tensor(val)]
tensor_data = [val for val in source.values() if not _pass_through(val)]

for val in tensor_data:
if _is_tensor_collection(type(val)):
Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11813,6 +11813,13 @@ class SubTC(NonTensorData): ...


class TestUnbatchedTensor:
def test_auto_batch_size(self):
td = TensorDict(a=UnbatchedTensor(0), b=torch.randn(2, 3)).auto_batch_size_(
batch_dims=2
)
assert td.shape == (2, 3)
assert td["a"] == 0

def test_unbatched(self):
assert UnbatchedTensor._pass_through
td = TensorDict(
Expand Down

0 comments on commit ba53d07

Please sign in to comment.