diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 0501042be..819556a6f 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -996,89 +996,103 @@ def test_pad(self):
         assert torch.equal(padded_td["a"], expected_a)
         padded_td._check_batch_size()
 
-    @pytest.mark.parametrize("pad_dim", [0, 1])
     @pytest.mark.parametrize("make_mask", [True, False])
-    def test_pad_sequence(self, pad_dim, make_mask):
-        if pad_dim == 0:
-            list_td = [
-                TensorDict(
-                    {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]
-                ),
-                TensorDict(
-                    {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)},
-                    [4],
-                ),
-            ]
-            padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
-            assert padded_td.shape == torch.Size(
-                [2, 4]
-            )  # check the shape of the padded tensordict
-            assert torch.all(
-                padded_td["a"][0, :2, :, :] == 1
-            )  # check the values of the first tensor
-            assert torch.all(
-                padded_td["a"][1, :, :, :] == 2
-            )  # check the values of the second tensor
-            assert padded_td["a"].shape == torch.Size(
-                [2, 4, 8, 8]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
-            assert padded_td["b", "c"].shape == torch.Size(
-                [2, 4, 3]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
-            if make_mask:
-                padded_td_without_masks = pad_sequence(
-                    list_td, pad_dim=pad_dim, return_mask=False
-                )
-                assert "masks" in padded_td.keys()
-                assert set(
-                    padded_td_without_masks.keys(include_nested=True, leaves_only=True)
-                ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
-                assert not padded_td["masks", "a"].all()
-                assert not padded_td["masks", "b", "c"].all()
-            else:
-                assert "masks" not in padded_td.keys()
+    def test_pad_sequence_pad_dim0(self, make_mask):
+        pad_dim = 0
+        list_td = [
+            TensorDict(
+                {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]
+            ),
+            TensorDict(
+                {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)},
+                [4],
+            ),
+        ]
+        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        assert padded_td.shape == torch.Size(
+            [2, 4]
+        )  # check the shape of the padded tensordict
+        assert torch.all(
+            padded_td["a"][0, :2, :, :] == 1
+        )  # check the values of the first tensor
+        assert torch.all(
+            padded_td["a"][1, :, :, :] == 2
+        )  # check the values of the second tensor
+        assert padded_td["a"].shape == torch.Size(
+            [2, 4, 8, 8]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
+        assert padded_td["b", "c"].shape == torch.Size(
+            [2, 4, 3]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
+        if make_mask:
+            padded_td_without_masks = pad_sequence(
+                list_td, pad_dim=pad_dim, return_mask=False
+            )
+            assert "masks" in padded_td.keys()
+            assert set(
+                padded_td_without_masks.keys(include_nested=True, leaves_only=True)
+            ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
+            assert not padded_td["masks", "a"].all()
+            assert padded_td["masks", "a"].ndim == pad_dim + 2
+            assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
+            assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
+            assert not padded_td["masks", "b", "c"].all()
+            assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
+            assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
+            assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
         else:
-            list_td = [
-                TensorDict(
-                    {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, []
-                ),
-                TensorDict(
-                    {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)},
-                    [],
-                ),
-            ]
-            padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
-            assert padded_td.shape == torch.Size(
-                [2]
-            )  # check the shape of the padded tensordict
-            assert padded_td["a"].shape == torch.Size(
-                [2, 6, 5, 8]
-            )  # check the shape of the padded tensor
-            assert torch.all(
-                padded_td["a"][0, :, :3, :] == 1
-            )  # check the values of the first tensor
-            assert torch.all(padded_td["a"][0, :, 3:, :] == 0)  # check the padding
-            assert torch.all(
-                padded_td["a"][1, :, :, :] == 2
-            )  # check the values of the second tensor
-            assert padded_td["b", "c"].shape == torch.Size(
-                [2, 6, 7]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["b", "c"][0, :, 3:] == 0)  # check the padding
-            if make_mask:
-                padded_td_without_masks = pad_sequence(
-                    list_td, pad_dim=pad_dim, return_mask=False
-                )
-                assert "masks" in padded_td.keys()
-                assert set(
-                    padded_td_without_masks.keys(include_nested=True, leaves_only=True)
-                ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
-                assert not padded_td["masks", "a"].all()
-                assert not padded_td["masks", "b", "c"].all()
-            else:
-                assert "masks" not in padded_td.keys()
+            assert "masks" not in padded_td.keys()
+
+    @pytest.mark.parametrize("make_mask", [True, False])
+    def test_pad_sequence_pad_dim1(self, make_mask):
+        pad_dim = 1
+        list_td = [
+            TensorDict(
+                {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6]
+            ),
+            TensorDict(
+                {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)},
+                [6],
+            ),
+        ]
+        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        assert padded_td.shape == torch.Size(
+            [2, 6]
+        )  # check the shape of the padded tensordict
+        assert padded_td["a"].shape == torch.Size(
+            [2, 6, 5, 8]
+        )  # check the shape of the padded tensor
+        assert torch.all(
+            padded_td["a"][0, :, :3, :] == 1
+        )  # check the values of the first tensor
+        assert torch.all(padded_td["a"][0, :, 3:, :] == 0)  # check the padding
+        assert torch.all(
+            padded_td["a"][1, :, :, :] == 2
+        )  # check the values of the second tensor
+        assert padded_td["b", "c"].shape == torch.Size(
+            [2, 6, 7]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["b", "c"][0, :, 3:] == 0)  # check the padding
+        if make_mask:
+            padded_td_without_masks = pad_sequence(
+                list_td, pad_dim=pad_dim, return_mask=False
+            )
+            assert "masks" in padded_td.keys()
+            assert set(
+                padded_td_without_masks.keys(include_nested=True, leaves_only=True)
+            ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
+            assert not padded_td["masks", "a"].all()
+            assert padded_td["masks", "a"].ndim == pad_dim + 2
+            assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
+            assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
+            assert not padded_td["masks", "b", "c"].all()
+            assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
+            assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
+            assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
+        else:
+            assert "masks" not in padded_td.keys()
 
     @pytest.mark.parametrize("device", get_available_devices())
     def test_permute(self, device):