Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 18, 2024
1 parent 4225911 commit 4952c18
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
26 changes: 26 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
_KEY_ERROR,
_proc_init,
_prune_selected_keys,
_set_max_batch_size,
_shape,
_split_tensordict,
_td_fields,
Expand Down Expand Up @@ -326,6 +327,31 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
"""
...

def auto_batch_size_(self, batch_dims: int | None = None) -> T:
"""Sets the maximum batch-size for the tensordict, up to an optional batch_dims.
Args:
batch_dims (int, optional): if provided, the batch-size will be at
most ``batch_dims`` long.
Returns:
self
Examples:
>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])
"""
_set_max_batch_size(self, batch_dims)
return self

# Module interaction
@classmethod
def from_module(
Expand Down
4 changes: 3 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,12 +1441,14 @@ def _expand_to_match_shape(

def _set_max_batch_size(source: T, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
from tensordict import NonTensorData

tensor_data = list(source.values())

for val in tensor_data:
from tensordict.base import _is_tensor_collection

if _is_tensor_collection(val.__class__):
if _is_tensor_collection(val.__class__) and not isinstance(val, NonTensorData):
_set_max_batch_size(val, batch_dims=batch_dims)
batch_size = []
if not tensor_data: # when source is empty
Expand Down
17 changes: 17 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,6 +1784,23 @@ def test_assert(self, td_name, device):
):
assert td

def test_auto_batch_size_(self, td_name, device):
td = getattr(self, td_name)(device)
batch_size = td.batch_size
error = None
try:
td.batch_size = []
except Exception as err:
error = err
if error is not None:
with pytest.raises(type(error)):
td.auto_batch_size_()
return
td.auto_batch_size_()
assert td.batch_size[: len(batch_size)] == batch_size
td.auto_batch_size_(1)
assert len(td.batch_size) == 1

def test_broadcast(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit 4952c18

Please sign in to comment.