[BugFix, Feature] pad_sequence refactoring (#652)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
dtsaras and vmoens authored Feb 25, 2024
1 parent dbb3363 commit b5f6c17
Showing 2 changed files with 181 additions and 67 deletions.
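In short, the commit deprecates the batch_first argument (the output is now always batch-first), adds a pad_dim argument selecting the dimension to pad, and nests the per-key validity masks under a "masks" entry. A minimal usage sketch of the new interface, assuming pad_sequence and TensorDict are importable from the top-level tensordict package:

import torch
from tensordict import TensorDict, pad_sequence  # assumed top-level exports

# two tensordicts whose "a" entries differ in length along dim 1
list_td = [
    TensorDict({"a": torch.ones((6, 3, 8))}, batch_size=[6]),
    TensorDict({"a": torch.full((6, 5, 8), 2.0)}, batch_size=[6]),
]

# pad every entry along dim 1 and request the validity masks
padded = pad_sequence(list_td, pad_dim=1, return_mask=True)
print(padded["a"].shape)           # torch.Size([2, 6, 5, 8])
print(padded["masks", "a"].shape)  # torch.Size([2, 6, 5]), True where values are valid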
136 changes: 90 additions & 46 deletions tensordict/functional.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

from typing import Sequence

import torch
@@ -77,7 +79,8 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:

def pad_sequence(
list_of_tensordicts: Sequence[T],
batch_first: bool = True,
batch_first: bool | None = None,
pad_dim: int = 0,
padding_value: float = 0.0,
out: T | None = None,
device: DeviceType | None = None,
@@ -87,76 +90,117 @@ def pad_sequence(
Args:
list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack.
batch_first (bool, optional): a now-deprecated ``batch_first`` equivalent of :func:`torch.nn.utils.rnn.pad_sequence`.
The output is always batch-first; passing this argument only raises a deprecation warning.
pad_dim (int, optional): the dimension along which every entry of the tensordicts is padded.
Defaults to ``0``.
padding_value (number, optional): the padding value. Defaults to ``0.0``.
out (TensorDictBase, optional): if provided, the destination where the data will be
written.
device (device compatible type, optional): if provided, the device where the
TensorDict output will be created.
return_mask (bool, optional): if ``True``, a "mask" entry will be returned.
It contains the mask of valid values in the stacked tensordict.
return_mask (bool, optional): if ``True``, a "masks" entry will be returned.
It contains a tensordict with the same structure as the stacked tensordict, where every entry holds a mask of valid values of size ``torch.Size([stack_len, *new_shape])``,
with ``new_shape[pad_dim] = max_seq_length`` and the remaining dimensions of ``new_shape`` matching those of the corresponding tensor.
Examples:
>>> list_td = [
... TensorDict({"a": torch.zeros((3,))}, []),
... TensorDict({"a": torch.zeros((4,))}, []),
... TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
... TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
... ]
>>> padded_td = pad_sequence(list_td)
>>> padded_td = pad_sequence(list_td, return_mask=True)
>>> print(padded_td)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
a: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 6, 8]), device=cpu, dtype=torch.float32, is_shared=False),
masks: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 5]), device=cpu, dtype=torch.bool, is_shared=False),
b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
"""
if batch_first is not None:
warnings.warn(
"The batch_first argument is deprecated and will be removed in a future release. The output will always be batch_first.",
category=DeprecationWarning,
)

if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")

# check that all tensordicts match
if return_mask:
list_of_tensordicts = [
td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool))
for td in list_of_tensordicts
]
update_batch_size = True
max_seq_length = float("-inf")
keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
shape = max(len(td) for td in list_of_tensordicts)
if shape == 0:
shape = [
len(list_of_tensordicts),
]
elif batch_first:
shape = [len(list_of_tensordicts), shape]
else:
shape = [shape, len(list_of_tensordicts)]
if out is None:
out = TensorDict(
{}, batch_size=torch.Size(shape), device=device, _run_checks=False
)
tmp_list_of_tensordicts = []
for td in list_of_tensordicts:

if return_mask:
tmp_list_of_tensordicts.append(td.clone(False))

for key in keys:
try:
out.set(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
padding_value=padding_value,
),
tensor_shape = td.get(key).shape
pos_pad_dim = pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim

# track the maximum sequence length to update batch_size accordingly
if tensor_shape[pos_pad_dim] > max_seq_length:
max_seq_length = tensor_shape[pos_pad_dim]

# The mask should always contain the batch_size of the TensorDict
mask_shape = td.shape

# if the pad_dim is past the batch_size of the TensorDict, we need to add the new dimension to the mask
if pos_pad_dim >= td.ndim:
mask_shape += torch.Size([tensor_shape[pos_pad_dim]])
update_batch_size = False

if return_mask:
tmp_list_of_tensordicts[-1].set(
("masks", key),
torch.ones(mask_shape, dtype=torch.bool),
)
except Exception as err:
raise RuntimeError(f"pad_sequence failed for key {key}") from err
return out
else:
for key in keys:
out.set_(
if return_mask:
list_of_tensordicts = tmp_list_of_tensordicts

keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)

old_batch_size = list(list_of_tensordicts[0].batch_size)
if update_batch_size and len(old_batch_size) > 0:
old_batch_size[pad_dim] = max_seq_length
shape = [
len(list_of_tensordicts),
] + old_batch_size

if out is None:
out = list_of_tensordicts[0].empty(recurse=True).reshape(torch.Size(shape))

for key in keys:
try:
tensor_shape = list_of_tensordicts[0].get(key).shape
pos_pad_dim = (
(pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim)
if len(tensor_shape) > 1
else 0 # handles the case when the masks are 1-dimensional
)
out.set(
key,
torch.nn.utils.rnn.pad_sequence(
[td.get(key) for td in list_of_tensordicts],
batch_first=batch_first,
[
td.get(key).transpose(0, pos_pad_dim)
for td in list_of_tensordicts
],
batch_first=True,
padding_value=padding_value,
),
).transpose(1, pos_pad_dim + 1),
inplace=True,
)
return out
except Exception as err:
raise RuntimeError(f"pad_sequence failed for key {key}") from err
return out


def merge_tensordicts(*tensordicts: T) -> T:
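The heart of the refactor is how an arbitrary pad_dim is handled: each tensor is transposed so that the padded dimension sits at dim 0, run through torch.nn.utils.rnn.pad_sequence with batch_first=True, and then transposed back (dim 0 of the result is the new stack dimension, hence the transpose(1, pos_pad_dim + 1)). Below is a standalone sketch of that trick on plain tensors; pad_along_dim is a hypothetical helper name, not part of the library:

import torch

def pad_along_dim(tensors, pad_dim=0, padding_value=0.0):
    # Stack tensors that differ in size along pad_dim, padding the shorter ones.
    pos_pad_dim = pad_dim if pad_dim >= 0 else tensors[0].dim() + pad_dim
    padded = torch.nn.utils.rnn.pad_sequence(
        [t.transpose(0, pos_pad_dim) for t in tensors],  # move pad_dim to the front
        batch_first=True,
        padding_value=padding_value,
    )
    # dim 0 is now the stack dimension, so the original pad_dim sits at pos_pad_dim + 1
    return padded.transpose(1, pos_pad_dim + 1)

x = torch.ones(6, 3, 8)
y = torch.full((6, 5, 8), 2.0)
print(pad_along_dim([x, y], pad_dim=1).shape)  # torch.Size([2, 6, 5, 8])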
112 changes: 91 additions & 21 deletions test/test_tensordict.py
@@ -996,33 +996,103 @@ def test_pad(self):
assert torch.equal(padded_td["a"], expected_a)
padded_td._check_batch_size()

@pytest.mark.parametrize("batch_first", [True, False])
@pytest.mark.parametrize("make_mask", [True, False])
def test_pad_sequence(self, batch_first, make_mask):
def test_pad_sequence_pad_dim0(self, make_mask):
pad_dim = 0
list_td = [
TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]),
TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]),
TensorDict(
{"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]
),
TensorDict(
{"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)},
[4],
),
]
padded_td = pad_sequence(
list_td, batch_first=batch_first, return_mask=make_mask
)
if batch_first:
assert padded_td.shape == torch.Size([2, 4])
assert padded_td["a"].shape == torch.Size([2, 4])
assert padded_td["a"][0, -1] == 0
assert padded_td["b", "c"].shape == torch.Size([2, 4, 3])
assert padded_td["b", "c"][0, -1, 0] == 0
padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
assert padded_td.shape == torch.Size(
[2, 4]
) # check the shape of the padded tensordict
assert torch.all(
padded_td["a"][0, :2, :, :] == 1
) # check the values of the first tensor
assert torch.all(
padded_td["a"][1, :, :, :] == 2
) # check the values of the second tensor
assert padded_td["a"].shape == torch.Size(
[2, 4, 8, 8]
) # check the shape of the padded tensor
assert torch.all(padded_td["a"][0, 2:, :, :] == 0) # check the padding
assert padded_td["b", "c"].shape == torch.Size(
[2, 4, 3]
) # check the shape of the padded tensor
assert torch.all(padded_td["b", "c"][0, 2:, :] == 0) # check the padding
if make_mask:
padded_td_without_masks = pad_sequence(
list_td, pad_dim=pad_dim, return_mask=False
)
assert "masks" in padded_td.keys()
assert set(
padded_td_without_masks.keys(include_nested=True, leaves_only=True)
) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
assert not padded_td["masks", "a"].all()
assert padded_td["masks", "a"].ndim == pad_dim + 2
assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
assert not padded_td["masks", "b", "c"].all()
assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
else:
assert padded_td.shape == torch.Size([4, 2])
assert padded_td["a"].shape == torch.Size([4, 2])
assert padded_td["a"][-1, 0] == 0
assert padded_td["b", "c"].shape == torch.Size([4, 2, 3])
assert padded_td["b", "c"][-1, 0, 0] == 0
assert "masks" not in padded_td.keys()

@pytest.mark.parametrize("make_mask", [True, False])
def test_pad_sequence_pad_dim1(self, make_mask):
pad_dim = 1
list_td = [
TensorDict(
{"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6]
),
TensorDict(
{"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)},
[6],
),
]
padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
assert padded_td.shape == torch.Size(
[2, 6]
) # check the shape of the padded tensordict
assert padded_td["a"].shape == torch.Size(
[2, 6, 5, 8]
) # check the shape of the padded tensor
assert torch.all(
padded_td["a"][0, :, :3, :] == 1
) # check the values of the first tensor
assert torch.all(padded_td["a"][0, :, 3:, :] == 0) # check the padding
assert torch.all(
padded_td["a"][1, :, :, :] == 2
) # check the values of the second tensor
assert padded_td["b", "c"].shape == torch.Size(
[2, 6, 7]
) # check the shape of the padded tensor
assert torch.all(padded_td["b", "c"][0, :, 3:] == 0) # check the padding
if make_mask:
assert "mask" in padded_td.keys()
assert not padded_td["mask"].all()
padded_td_without_masks = pad_sequence(
list_td, pad_dim=pad_dim, return_mask=False
)
assert "masks" in padded_td.keys()
assert set(
padded_td_without_masks.keys(include_nested=True, leaves_only=True)
) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
assert not padded_td["masks", "a"].all()
assert padded_td["masks", "a"].ndim == pad_dim + 2
assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
assert not padded_td["masks", "b", "c"].all()
assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
else:
assert "mask" not in padded_td.keys()
assert "masks" not in padded_td.keys()

@pytest.mark.parametrize("device", get_available_devices())
def test_permute(self, device):
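The new tests check that the masks mark exactly the valid (non-padded) positions. A condensed version of that check, again assuming the top-level tensordict imports:

import torch
from tensordict import TensorDict, pad_sequence  # assumed top-level exports

list_td = [
    TensorDict({"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]),
    TensorDict({"a": torch.full((4, 8, 8), 2.0), ("b", "c"): torch.full((4, 3), 2.0)}, [4]),
]
padded = pad_sequence(list_td, pad_dim=0, return_mask=True)

assert padded.batch_size == torch.Size([2, 4])
assert padded["a"].shape == torch.Size([2, 4, 8, 8])
# positions flagged valid hold the original values, everything else is padding
assert (padded["a"][padded["masks", "a"]] != 0).all()
assert (padded["a"][~padded["masks", "a"]] == 0).all()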
