diff --git a/tensordict/functional.py b/tensordict/functional.py
index 1cbf168b7..cfb7bff4e 100644
--- a/tensordict/functional.py
+++ b/tensordict/functional.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import warnings
+
 from typing import Sequence
 
 import torch
@@ -77,7 +79,8 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:
 
 def pad_sequence(
     list_of_tensordicts: Sequence[T],
-    batch_first: bool = True,
+    batch_first: bool | None = None,
+    pad_dim: int = 0,
     padding_value: float = 0.0,
     out: T | None = None,
     device: DeviceType | None = None,
@@ -87,76 +90,117 @@ def pad_sequence(
     Args:
         list_of_tensordicts (List[TensorDictBase]): the list of instances to pad and stack.
-        batch_first (bool, optional): the ``batch_first`` correspondant of :func:`torch.nn.utils.rnn.pad_sequence`.
-            Defaults to ``True``.
+        pad_dim (int, optional): the dimension along which the entries of the tensordicts are padded.
+            Defaults to ``0``.
         padding_value (number, optional): the padding value. Defaults to ``0.0``.
         out (TensorDictBase, optional): if provided, the destination where the data
             will be written.
         device (device compatible type, optional): if provided, the device where the
             TensorDict output will be created.
-        return_mask (bool, optional): if ``True``, a "mask" entry will be returned.
-            It contains the mask of valid values in the stacked tensordict.
+        return_mask (bool, optional): if ``True``, a "masks" entry will be returned.
+            It contains a tensordict with the same structure as the stacked tensordict,
+            where every entry contains the mask of valid values, of size ``torch.Size([stack_len, *new_shape])``,
+            where ``new_shape[pad_dim] = max_seq_length`` and the rest of ``new_shape``
+            matches the previous shape of the contained tensors.
 
     Examples:
         >>> list_td = [
-        ...     TensorDict({"a": torch.zeros((3,))}, []),
-        ...     TensorDict({"a": torch.zeros((4,))}, []),
+        ...     TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
+        ...     TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
         ...     ]
-        >>> padded_td = pad_sequence(list_td)
+        >>> padded_td = pad_sequence(list_td, return_mask=True)
        >>> print(padded_td)
         TensorDict(
             fields={
-                a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
-            batch_size=torch.Size([]),
+                a: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: Tensor(shape=torch.Size([2, 6, 8]), device=cpu, dtype=torch.float32, is_shared=False),
+                masks: TensorDict(
+                    fields={
+                        a: Tensor(shape=torch.Size([2, 5]), device=cpu, dtype=torch.bool, is_shared=False),
+                        b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False)},
+            batch_size=torch.Size([2]),
             device=None,
             is_shared=False)
     """
+    if batch_first is not None:
+        warnings.warn(
+            "The batch_first argument is deprecated and will be removed in a future release. The output will always be batch_first.",
+            category=DeprecationWarning,
+        )
+
     if not list_of_tensordicts:
         raise RuntimeError("list_of_tensordicts cannot be empty")
+
     # check that all tensordict match
-    if return_mask:
-        list_of_tensordicts = [
-            td.clone(False).set("mask", torch.ones(td.shape, dtype=torch.bool))
-            for td in list_of_tensordicts
-        ]
+    update_batch_size = True
+    max_seq_length = float("-inf")
     keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
-    shape = max(len(td) for td in list_of_tensordicts)
-    if shape == 0:
-        shape = [
-            len(list_of_tensordicts),
-        ]
-    elif batch_first:
-        shape = [len(list_of_tensordicts), shape]
-    else:
-        shape = [shape, len(list_of_tensordicts)]
-    if out is None:
-        out = TensorDict(
-            {}, batch_size=torch.Size(shape), device=device, _run_checks=False
-        )
+    tmp_list_of_tensordicts = []
+    for td in list_of_tensordicts:
+
+        if return_mask:
+            tmp_list_of_tensordicts.append(td.clone(False))
+
         for key in keys:
-            try:
-                out.set(
-                    key,
-                    torch.nn.utils.rnn.pad_sequence(
-                        [td.get(key) for td in list_of_tensordicts],
-                        batch_first=batch_first,
-                        padding_value=padding_value,
-                    ),
+            tensor_shape = td.get(key).shape
+            pos_pad_dim = pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim
+
+            # track the maximum sequence length to update batch_size accordingly
+            if tensor_shape[pos_pad_dim] > max_seq_length:
+                max_seq_length = tensor_shape[pos_pad_dim]
+
+            # The mask should always contain the batch_size of the TensorDict
+            mask_shape = td.shape
+
+            # if the pad_dim is past the batch_size of the TensorDict, we need to add the new dimension to the mask
+            if pos_pad_dim >= td.ndim:
+                mask_shape += torch.Size([tensor_shape[pos_pad_dim]])
+                update_batch_size = False
+
+            if return_mask:
+                tmp_list_of_tensordicts[-1].set(
+                    ("masks", key),
+                    torch.ones(mask_shape, dtype=torch.bool),
                 )
-            except Exception as err:
-                raise RuntimeError(f"pad_sequence failed for key {key}") from err
-        return out
-    else:
-        for key in keys:
-            out.set_(
+    if return_mask:
+        list_of_tensordicts = tmp_list_of_tensordicts
+
+    keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
+
+    old_batch_size = list(list_of_tensordicts[0].batch_size)
+    if update_batch_size and len(old_batch_size) > 0:
+        old_batch_size[pad_dim] = max_seq_length
+    shape = [
+        len(list_of_tensordicts),
+    ] + old_batch_size
+
+    if out is None:
+        out = list_of_tensordicts[0].empty(recurse=True).reshape(torch.Size(shape))
+
+    for key in keys:
+        try:
+            tensor_shape = list_of_tensordicts[0].get(key).shape
+            pos_pad_dim = (
+                (pad_dim if pad_dim >= 0 else len(tensor_shape) + pad_dim)
+                if len(tensor_shape) > 1
+                else 0  # handles the case when the masks are 1-dimensional
+            )
+            out.set(
                 key,
                 torch.nn.utils.rnn.pad_sequence(
-                    [td.get(key) for td in list_of_tensordicts],
-                    batch_first=batch_first,
+                    [
+                        td.get(key).transpose(0, pos_pad_dim)
+                        for td in list_of_tensordicts
+                    ],
+                    batch_first=True,
                     padding_value=padding_value,
-                ),
+                ).transpose(1, pos_pad_dim + 1),
+                inplace=True,
             )
-        return out
+        except Exception as err:
+            raise RuntimeError(f"pad_sequence failed for key {key}") from err
+    return out
 
 
 def merge_tensordicts(*tensordicts: T) -> T:
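A minimal usage sketch of the patched function, mirroring the tests below. It assumes this patch is applied and that `pad_sequence` is importable from the top-level `tensordict` package; the `"obs"` key and shapes are illustrative, not from the patch:

```python
import torch
from tensordict import TensorDict, pad_sequence

# Two tensordicts whose entries have different lengths along dim 0.
list_td = [
    TensorDict({"obs": torch.ones(3, 8)}, batch_size=[3]),
    TensorDict({"obs": torch.full((5, 8), 2.0)}, batch_size=[5]),
]

padded = pad_sequence(list_td, pad_dim=0, return_mask=True)

# Entries are stacked along a new leading dim and right-padded with zeros
# along pad_dim; batch_size picks up the max length along that dim.
assert padded.batch_size == torch.Size([2, 5])
assert padded["obs"].shape == torch.Size([2, 5, 8])

# "masks" mirrors the structure of the stacked tensordict and flags valid values.
assert padded["masks", "obs"].shape == torch.Size([2, 5])
assert (padded["obs"][padded["masks", "obs"]] != 0).all()   # real values: 1s and 2s
assert (padded["obs"][~padded["masks", "obs"]] == 0).all()  # padded positions
```

Note that `batch_first` is kept only for backward compatibility: passing it now emits a `DeprecationWarning`, and the output is always batch-first.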
@pytest.mark.parametrize("make_mask", [True, False]) - def test_pad_sequence(self, batch_first, make_mask): + def test_pad_sequence_pad_dim0(self, make_mask): + pad_dim = 0 list_td = [ - TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]), - TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]), + TensorDict( + {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2] + ), + TensorDict( + {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)}, + [4], + ), ] - padded_td = pad_sequence( - list_td, batch_first=batch_first, return_mask=make_mask - ) - if batch_first: - assert padded_td.shape == torch.Size([2, 4]) - assert padded_td["a"].shape == torch.Size([2, 4]) - assert padded_td["a"][0, -1] == 0 - assert padded_td["b", "c"].shape == torch.Size([2, 4, 3]) - assert padded_td["b", "c"][0, -1, 0] == 0 + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 4] + ) # check the shape of the padded tensordict + assert torch.all( + padded_td["a"][0, :2, :, :] == 1 + ) # check the values of the first tensor + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["a"].shape == torch.Size( + [2, 4, 8, 8] + ) # check the shape of the padded tensor + assert torch.all(padded_td["a"][0, 2:, :, :] == 0) # check the padding + assert padded_td["b", "c"].shape == torch.Size( + [2, 4, 3] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, 2:, :] == 0) # check the padding + if make_mask: + padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert padded_td.shape == torch.Size([4, 2]) - assert padded_td["a"].shape == torch.Size([4, 2]) - assert padded_td["a"][-1, 0] == 0 - assert padded_td["b", "c"].shape == torch.Size([4, 2, 3]) - assert padded_td["b", "c"][-1, 0, 0] == 0 + assert "masks" not in padded_td.keys() + + @pytest.mark.parametrize("make_mask", [True, False]) + def test_pad_sequence_pad_dim1(self, make_mask): + pad_dim = 1 + list_td = [ + TensorDict( + {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6] + ), + TensorDict( + {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)}, + [6], + ), + ] + padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask) + assert padded_td.shape == torch.Size( + [2, 6] + ) # check the shape of the padded tensordict + assert padded_td["a"].shape == torch.Size( + [2, 6, 5, 8] + ) # check the shape of the padded tensor + assert torch.all( + padded_td["a"][0, :, :3, :] == 1 + ) # check the values of the first tensor + assert torch.all(padded_td["a"][0, :, 3:, :] == 0) # check the padding + assert torch.all( + padded_td["a"][1, :, :, :] == 2 + ) # check the values of the second tensor + assert padded_td["b", 
"c"].shape == torch.Size( + [2, 6, 7] + ) # check the shape of the padded tensor + assert torch.all(padded_td["b", "c"][0, :, 3:] == 0) # check the padding if make_mask: - assert "mask" in padded_td.keys() - assert not padded_td["mask"].all() + padded_td_without_masks = pad_sequence( + list_td, pad_dim=pad_dim, return_mask=False + ) + assert "masks" in padded_td.keys() + assert set( + padded_td_without_masks.keys(include_nested=True, leaves_only=True) + ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True)) + assert not padded_td["masks", "a"].all() + assert padded_td["masks", "a"].ndim == pad_dim + 2 + assert (padded_td["a"][padded_td["masks", "a"]] != 0).all() + assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all() + assert not padded_td["masks", "b", "c"].all() + assert padded_td["masks", "b", "c"].ndim == pad_dim + 2 + assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all() + assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all() else: - assert "mask" not in padded_td.keys() + assert "masks" not in padded_td.keys() @pytest.mark.parametrize("device", get_available_devices()) def test_permute(self, device):