Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Proper masks for padding with custom pad value #1185

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,19 @@ def pad_sequence(
"plase convert the tensorclasses to TensorDicts first."
)

masks_key = "masks"
if not isinstance(return_mask, bool):
masks_key = unravel_key(return_mask)
return_mask = True
else:
masks_key = "masks"

# check that all tensordict match
update_batch_size = True
max_seq_length = float("-inf")
keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
list_of_dicts = [{} for _ in range(len(list_of_tensordicts))]
keys_copy = list(keys)
mask_keys = []
for i, td in enumerate(list_of_tensordicts):
if is_tensorclass(td):
td = td._tensordict
Expand Down Expand Up @@ -197,6 +199,7 @@ def pad_sequence(

if return_mask:
mask_key = unravel_key((masks_key, key))
mask_keys.append(mask_key)
list_of_dicts[i][mask_key] = torch.ones(mask_shape, dtype=torch.bool)
keys_copy.append(mask_key)

Expand Down Expand Up @@ -229,7 +232,7 @@ def pad_sequence(
torch.nn.utils.rnn.pad_sequence(
[d[key].transpose(0, pos_pad_dim) for d in list_of_dicts],
batch_first=True,
padding_value=padding_value,
padding_value=padding_value if key not in mask_keys else False,
).transpose(1, pos_pad_dim + 1),
inplace=True,
)
Expand Down
25 changes: 16 additions & 9 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1942,7 +1942,8 @@ class Sample:
assert d.b == ["asd", "efg"]

@pytest.mark.parametrize("make_mask", [True, ("bibbidi", "bobbidi", "boo"), False])
def test_pad_sequence_pad_dim0(self, make_mask):
@pytest.mark.parametrize("pad_val", [0, -1])
def test_pad_sequence_pad_dim0(self, make_mask, pad_val):
pad_dim = 0
list_td = [
TensorDict(
Expand All @@ -1953,7 +1954,9 @@ def test_pad_sequence_pad_dim0(self, make_mask):
[4],
),
]
padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
padded_td = pad_sequence(
list_td, pad_dim=pad_dim, return_mask=make_mask, padding_value=pad_val
)
assert padded_td.shape == torch.Size(
[2, 4]
) # check the shape of the padded tensordict
Expand All @@ -1966,30 +1969,34 @@ def test_pad_sequence_pad_dim0(self, make_mask):
assert padded_td["a"].shape == torch.Size(
[2, 4, 8, 8]
) # check the shape of the padded tensor
assert torch.all(padded_td["a"][0, 2:, :, :] == 0) # check the padding
assert torch.all(padded_td["a"][0, 2:, :, :] == pad_val) # check the padding
assert padded_td["b", "c"].shape == torch.Size(
[2, 4, 3]
) # check the shape of the padded tensor
assert torch.all(padded_td["b", "c"][0, 2:, :] == 0) # check the padding
assert torch.all(padded_td["b", "c"][0, 2:, :] == pad_val) # check the padding
if make_mask:
masks_key = "masks"
if not isinstance(make_mask, bool):
masks_key = make_mask
padded_td_without_masks = pad_sequence(
list_td, pad_dim=pad_dim, return_mask=False
list_td, pad_dim=pad_dim, return_mask=False, padding_value=pad_val
)
assert masks_key in padded_td.keys(True)
assert set(
padded_td_without_masks.keys(include_nested=True, leaves_only=True)
) == set(padded_td[masks_key].keys(include_nested=True, leaves_only=True))
assert not padded_td[masks_key, "a"].all()
assert padded_td[masks_key, "a"].ndim == pad_dim + 2
assert (padded_td["a"][padded_td[masks_key, "a"]] != 0).all()
assert (padded_td["a"][~padded_td[masks_key, "a"]] == 0).all()
assert (padded_td["a"][padded_td[masks_key, "a"]] != pad_val).all()
assert (padded_td["a"][~padded_td[masks_key, "a"]] == pad_val).all()
assert not padded_td[masks_key, "b", "c"].all()
assert padded_td[masks_key, "b", "c"].ndim == pad_dim + 2
assert (padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != 0).all()
assert (padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == 0).all()
assert (
padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != pad_val
).all()
assert (
padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == pad_val
).all()
else:
assert "masks" not in padded_td.keys()

Expand Down
Loading