diff --git a/tensordict/functional.py b/tensordict/functional.py
index 226e55da1..2a19ad47f 100644
--- a/tensordict/functional.py
+++ b/tensordict/functional.py
@@ -156,10 +156,11 @@ def pad_sequence(
             "plase convert the tensorclasses to TensorDicts first."
         )
 
-    masks_key = "masks"
     if not isinstance(return_mask, bool):
         masks_key = unravel_key(return_mask)
         return_mask = True
+    else:
+        masks_key = "masks"
 
     # check that all tensordict match
     update_batch_size = True
@@ -167,6 +168,7 @@ def pad_sequence(
     keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
     list_of_dicts = [{} for _ in range(len(list_of_tensordicts))]
     keys_copy = list(keys)
+    mask_keys = []
     for i, td in enumerate(list_of_tensordicts):
         if is_tensorclass(td):
             td = td._tensordict
@@ -197,6 +199,7 @@ def pad_sequence(
 
             if return_mask:
                 mask_key = unravel_key((masks_key, key))
+                mask_keys.append(mask_key)
                 list_of_dicts[i][mask_key] = torch.ones(mask_shape, dtype=torch.bool)
                 keys_copy.append(mask_key)
 
@@ -229,7 +232,7 @@ def pad_sequence(
             torch.nn.utils.rnn.pad_sequence(
                 [d[key].transpose(0, pos_pad_dim) for d in list_of_dicts],
                 batch_first=True,
-                padding_value=padding_value,
+                padding_value=padding_value if key not in mask_keys else False,
             ).transpose(1, pos_pad_dim + 1),
             inplace=True,
         )
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 0fc698f59..ebf1f76bd 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -1942,7 +1942,8 @@ class Sample:
         assert d.b == ["asd", "efg"]
 
     @pytest.mark.parametrize("make_mask", [True, ("bibbidi", "bobbidi", "boo"), False])
-    def test_pad_sequence_pad_dim0(self, make_mask):
+    @pytest.mark.parametrize("pad_val", [0, -1])
+    def test_pad_sequence_pad_dim0(self, make_mask, pad_val):
         pad_dim = 0
         list_td = [
             TensorDict(
@@ -1953,7 +1954,9 @@ def test_pad_sequence_pad_dim0(self, make_mask):
                 [4],
             ),
         ]
-        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        padded_td = pad_sequence(
+            list_td, pad_dim=pad_dim, return_mask=make_mask, padding_value=pad_val
+        )
         assert padded_td.shape == torch.Size(
             [2, 4]
         )  # check the shape of the padded tensordict
@@ -1966,17 +1969,17 @@ def test_pad_sequence_pad_dim0(self, make_mask):
         assert padded_td["a"].shape == torch.Size(
             [2, 4, 8, 8]
         )  # check the shape of the padded tensor
-        assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
+        assert torch.all(padded_td["a"][0, 2:, :, :] == pad_val)  # check the padding
        assert padded_td["b", "c"].shape == torch.Size(
             [2, 4, 3]
         )  # check the shape of the padded tensor
-        assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
+        assert torch.all(padded_td["b", "c"][0, 2:, :] == pad_val)  # check the padding
         if make_mask:
             masks_key = "masks"
             if not isinstance(make_mask, bool):
                 masks_key = make_mask
             padded_td_without_masks = pad_sequence(
-                list_td, pad_dim=pad_dim, return_mask=False
+                list_td, pad_dim=pad_dim, return_mask=False, padding_value=pad_val
             )
             assert masks_key in padded_td.keys(True)
             assert set(
@@ -1984,12 +1987,16 @@
             ) == set(padded_td[masks_key].keys(include_nested=True, leaves_only=True))
             assert not padded_td[masks_key, "a"].all()
             assert padded_td[masks_key, "a"].ndim == pad_dim + 2
-            assert (padded_td["a"][padded_td[masks_key, "a"]] != 0).all()
-            assert (padded_td["a"][~padded_td[masks_key, "a"]] == 0).all()
+            assert (padded_td["a"][padded_td[masks_key, "a"]] != pad_val).all()
+            assert (padded_td["a"][~padded_td[masks_key, "a"]] == pad_val).all()
             assert not padded_td[masks_key, "b", "c"].all()
             assert padded_td[masks_key, "b", "c"].ndim == pad_dim + 2
-            assert (padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != 0).all()
-            assert (padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == 0).all()
+            assert (
+                padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != pad_val
+            ).all()
+            assert (
+                padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == pad_val
+            ).all()
         else:
             assert "masks" not in padded_td.keys()
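
Illustration only, not part of the patch: a minimal sketch of the behavior the change targets, using the same pad_sequence arguments exercised by the test above. Before this fix, the generated boolean mask entries were padded with padding_value itself, so a non-zero value such as -1 cast to True and marked padded positions as valid; padding the mask keys with False keeps the masks meaningful while the data entries still receive the requested value.

import torch
from tensordict import TensorDict, pad_sequence

# Two TensorDicts of different lengths along dim 0.
list_td = [
    TensorDict({"a": torch.ones(2, 8)}, [2]),
    TensorDict({"a": torch.ones(4, 8)}, [4]),
]

# Pad with a non-zero value and request masks.
padded = pad_sequence(list_td, pad_dim=0, padding_value=-1, return_mask=True)

assert (padded["a"][0, 2:] == -1).all()       # data is padded with -1
assert padded["masks", "a"][0, :2].all()      # real steps are marked True
assert not padded["masks", "a"][0, 2:].any()  # padded steps stay False instead of being cast to True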