
Commit d74940a ("amend")
vmoens committed Mar 28, 2024 (1 parent: 8da0ed7)
Showing 2 changed files with 57 additions and 3 deletions.
44 changes: 41 additions & 3 deletions tensordict/base.py
@@ -4849,19 +4849,57 @@ def copy(self):
"""
return self.clone(recurse=False)

def to_padded_tensor(self, padding=0.0):
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly."""
def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None):
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly.
Args:
padding (float): the padding value for the tensors in the tensordict.
Defaults to ``0.0``.
mask_key (NestedKey, optional): if provided, the key where a
mask for valid values will be written.
Will result in an error if the heterogeneous dimension
isn't part of the tensordict batch-size.
Defaults to ``None``
"""
batch_size = self.batch_size
if any(shape == -1 for shape in batch_size):
new_batch_size = []
else:
new_batch_size = None
if mask_key is not None:
raise RuntimeError(
"mask_key should only be provided if the "
"heterogenous dimension is part of the batch-size."
)
padded_names = []

def to_padded(name, x):
if x.is_nested:
padded_names.append(name)
return torch.nested.to_padded_tensor(x, padding=padding)
return x

result = self._apply_nest(
lambda x: torch.nested.to_padded_tensor(x, padding=padding) if x.is_nested else x,
to_padded,
batch_size=new_batch_size,
named=True,
nested_keys=True,
)
if new_batch_size is not None:
result = result.auto_batch_size_(batch_dims=self.batch_dims)

if mask_key:
# take the first of the padded keys
padded_key = padded_names[0]
# write the mask
val = self.get(padded_key)
val = torch.nested.to_padded_tensor(
torch.ones_like(val, dtype=torch.bool), padding=False
)
if val.ndim > result.ndim:
val = val.flatten(result.ndim, -1)[..., -1].clone()
result.set(mask_key, val)
return result

    def as_tensor(self):
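For context, here is a minimal usage sketch of the new mask_key argument (not part of the commit; it assumes a recent PyTorch with torch.nested support and mirrors the test below):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {
            "nested": torch.nested.nested_tensor(
                [torch.ones(3, 4, 5), torch.ones(3, 6, 5)]
            )
        },
        batch_size=[2, 3, -1],  # -1 flags the heterogeneous (ragged) dimension
    )
    td_padded = td.to_padded_tensor(padding=0.0, mask_key="mask")
    # The ragged dim is padded to its max length and the batch-size resolved:
    # td_padded.shape == torch.Size([2, 3, 6])
    # td_padded["mask"] is a boolean tensor over the batch dims, True where valid.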
16 changes: 16 additions & 0 deletions test/test_tensordict.py
@@ -1862,6 +1862,22 @@ def hook(
        assert set(params_reg.flatten_keys(".").keys()) == set(sd.keys())
        assert_allclose_td(params_reg.flatten_keys("."), TensorDict(sd, []))

    @pytest.mark.parametrize("mask_key", [None, "mask"])
    def test_to_padded_tensor(self, mask_key):
        td = TensorDict(
            {
                "nested": torch.nested.nested_tensor(
                    [torch.ones(3, 4, 5), torch.ones(3, 6, 5)]
                )
            },
            batch_size=[2, 3, -1],
        )
        assert td.shape == torch.Size([2, 3, -1])
        td_padded = td.to_padded_tensor(padding=0, mask_key=mask_key)
        assert td_padded.shape == torch.Size([2, 3, 6])
        if mask_key:
            # Padded positions hold 0; every mask-selected entry is an original one.
            assert (td_padded[td_padded["mask"]] != 0).all()

    def test_unbind_batchsize(self):
        td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2])
        td["a"].batch_size
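The mask written at mask_key is boolean over the batch dimensions, so it can also be used to index the padded tensordict back down to its valid entries, as the test's last assertion does. A sketch of that round trip (not in the commit; the shapes follow the test above):

    td_padded = td.to_padded_tensor(padding=0.0, mask_key="mask")
    valid = td_padded[td_padded["mask"]]  # boolean indexing over the batch dims
    # valid["nested"] stacks the non-padded rows: 3 * (4 + 6) = 30 entries of
    # feature size 5, i.e. shape [30, 5], all equal to the original values.
    assert (valid["nested"] == 1).all()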
