diff --git a/tensordict/_td.py b/tensordict/_td.py index 0147d63ca..983384cd1 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -62,6 +62,7 @@ _NON_STR_KEY_ERR, _NON_STR_KEY_TUPLE_ERR, _parse_to, + _pass_through, _prune_selected_keys, _set_item, _set_max_batch_size, @@ -3809,7 +3810,7 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): def _get_str(self, key, default): if key in self.keys() and _is_tensor_collection(self.entry_class(key)): data = self._source._get_str(key, NO_DEFAULT) - if is_non_tensor(data): + if _pass_through(data): return data[self.idx] return _SubTensorDict(data, self.idx) return self._source._get_at_str(key, self.idx, default=default) diff --git a/tensordict/utils.py b/tensordict/utils.py index c019092bf..a43840ef8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1788,7 +1788,7 @@ def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximum batch size.""" from tensordict.base import _is_tensor_collection - tensor_data = [val for val in source.values() if not is_non_tensor(val)] + tensor_data = [val for val in source.values() if not _pass_through(val)] for val in tensor_data: if _is_tensor_collection(type(val)): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 29831ccd1..343af93b5 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -11813,6 +11813,13 @@ class SubTC(NonTensorData): ... class TestUnbatchedTensor: + def test_auto_batch_size(self): + td = TensorDict(a=UnbatchedTensor(0), b=torch.randn(2, 3)).auto_batch_size_( + batch_dims=2 + ) + assert td.shape == (2, 3) + assert td["a"] == 0 + def test_unbatched(self): assert UnbatchedTensor._pass_through td = TensorDict(