
Commit

amend
vmoens committed Feb 25, 2024
1 parent f47ac14 commit ad15fd1
Showing 1 changed file with 95 additions and 81 deletions.
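In short, this commit splits the parametrized test_pad_sequence into two standalone tests, test_pad_sequence_pad_dim0 and test_pad_sequence_pad_dim1; gives the pad_dim=1 inputs explicit batch sizes ([6] instead of []); and strengthens the mask checks: each mask must have pad_dim + 2 dimensions, select only the nonzero (non-padded) values, and zero out everything it excludes.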
176 changes: 95 additions & 81 deletions test/test_tensordict.py
@@ -996,89 +996,103 @@ def test_pad(self):
         assert torch.equal(padded_td["a"], expected_a)
         padded_td._check_batch_size()

-    @pytest.mark.parametrize("pad_dim", [0, 1])
     @pytest.mark.parametrize("make_mask", [True, False])
-    def test_pad_sequence(self, pad_dim, make_mask):
-        if pad_dim == 0:
-            list_td = [
-                TensorDict(
-                    {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]
-                ),
-                TensorDict(
-                    {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)},
-                    [4],
-                ),
-            ]
-            padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
-            assert padded_td.shape == torch.Size(
-                [2, 4]
-            )  # check the shape of the padded tensordict
-            assert torch.all(
-                padded_td["a"][0, :2, :, :] == 1
-            )  # check the values of the first tensor
-            assert torch.all(
-                padded_td["a"][1, :, :, :] == 2
-            )  # check the values of the second tensor
-            assert padded_td["a"].shape == torch.Size(
-                [2, 4, 8, 8]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
-            assert padded_td["b", "c"].shape == torch.Size(
-                [2, 4, 3]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
-            if make_mask:
-                padded_td_without_masks = pad_sequence(
-                    list_td, pad_dim=pad_dim, return_mask=False
-                )
-                assert "masks" in padded_td.keys()
-                assert set(
-                    padded_td_without_masks.keys(include_nested=True, leaves_only=True)
-                ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
-                assert not padded_td["masks", "a"].all()
-                assert not padded_td["masks", "b", "c"].all()
-            else:
-                assert "masks" not in padded_td.keys()
+    def test_pad_sequence_pad_dim0(self, make_mask):
+        pad_dim = 0
+        list_td = [
+            TensorDict(
+                {"a": torch.ones((2, 8, 8)), ("b", "c"): torch.ones((2, 3))}, [2]
+            ),
+            TensorDict(
+                {"a": torch.full((4, 8, 8), 2), ("b", "c"): torch.full((4, 3), 2)},
+                [4],
+            ),
+        ]
+        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        assert padded_td.shape == torch.Size(
+            [2, 4]
+        )  # check the shape of the padded tensordict
+        assert torch.all(
+            padded_td["a"][0, :2, :, :] == 1
+        )  # check the values of the first tensor
+        assert torch.all(
+            padded_td["a"][1, :, :, :] == 2
+        )  # check the values of the second tensor
+        assert padded_td["a"].shape == torch.Size(
+            [2, 4, 8, 8]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
+        assert padded_td["b", "c"].shape == torch.Size(
+            [2, 4, 3]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
+        if make_mask:
+            padded_td_without_masks = pad_sequence(
+                list_td, pad_dim=pad_dim, return_mask=False
+            )
+            assert "masks" in padded_td.keys()
+            assert set(
+                padded_td_without_masks.keys(include_nested=True, leaves_only=True)
+            ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
+            assert not padded_td["masks", "a"].all()
+            assert padded_td["masks", "a"].ndim == pad_dim + 2
+            assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
+            assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
+            assert not padded_td["masks", "b", "c"].all()
+            assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
+            assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
+            assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
         else:
-            list_td = [
-                TensorDict(
-                    {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, []
-                ),
-                TensorDict(
-                    {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)},
-                    [],
-                ),
-            ]
-            padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
-            assert padded_td.shape == torch.Size(
-                [2]
-            )  # check the shape of the padded tensordict
-            assert padded_td["a"].shape == torch.Size(
-                [2, 6, 5, 8]
-            )  # check the shape of the padded tensor
-            assert torch.all(
-                padded_td["a"][0, :, :3, :] == 1
-            )  # check the values of the first tensor
-            assert torch.all(padded_td["a"][0, :, 3:, :] == 0)  # check the padding
-            assert torch.all(
-                padded_td["a"][1, :, :, :] == 2
-            )  # check the values of the second tensor
-            assert padded_td["b", "c"].shape == torch.Size(
-                [2, 6, 7]
-            )  # check the shape of the padded tensor
-            assert torch.all(padded_td["b", "c"][0, :, 3:] == 0)  # check the padding
-            if make_mask:
-                padded_td_without_masks = pad_sequence(
-                    list_td, pad_dim=pad_dim, return_mask=False
-                )
-                assert "masks" in padded_td.keys()
-                assert set(
-                    padded_td_without_masks.keys(include_nested=True, leaves_only=True)
-                ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
-                assert not padded_td["masks", "a"].all()
-                assert not padded_td["masks", "b", "c"].all()
-            else:
-                assert "masks" not in padded_td.keys()
+            assert "masks" not in padded_td.keys()

+    @pytest.mark.parametrize("make_mask", [True, False])
+    def test_pad_sequence_pad_dim1(self, make_mask):
+        pad_dim = 1
+        list_td = [
+            TensorDict(
+                {"a": torch.ones((6, 3, 8)), ("b", "c"): torch.ones((6, 3))}, [6]
+            ),
+            TensorDict(
+                {"a": torch.full((6, 5, 8), 2), ("b", "c"): torch.full((6, 7), 2)},
+                [6],
+            ),
+        ]
+        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        assert padded_td.shape == torch.Size(
+            [2, 6]
+        )  # check the shape of the padded tensordict
+        assert padded_td["a"].shape == torch.Size(
+            [2, 6, 5, 8]
+        )  # check the shape of the padded tensor
+        assert torch.all(
+            padded_td["a"][0, :, :3, :] == 1
+        )  # check the values of the first tensor
+        assert torch.all(padded_td["a"][0, :, 3:, :] == 0)  # check the padding
+        assert torch.all(
+            padded_td["a"][1, :, :, :] == 2
+        )  # check the values of the second tensor
+        assert padded_td["b", "c"].shape == torch.Size(
+            [2, 6, 7]
+        )  # check the shape of the padded tensor
+        assert torch.all(padded_td["b", "c"][0, :, 3:] == 0)  # check the padding
+        if make_mask:
+            padded_td_without_masks = pad_sequence(
+                list_td, pad_dim=pad_dim, return_mask=False
+            )
+            assert "masks" in padded_td.keys()
+            assert set(
+                padded_td_without_masks.keys(include_nested=True, leaves_only=True)
+            ) == set(padded_td["masks"].keys(include_nested=True, leaves_only=True))
+            assert not padded_td["masks", "a"].all()
+            assert padded_td["masks", "a"].ndim == pad_dim + 2
+            assert (padded_td["a"][padded_td["masks", "a"]] != 0).all()
+            assert (padded_td["a"][~padded_td["masks", "a"]] == 0).all()
+            assert not padded_td["masks", "b", "c"].all()
+            assert padded_td["masks", "b", "c"].ndim == pad_dim + 2
+            assert (padded_td["b", "c"][padded_td["masks", "b", "c"]] != 0).all()
+            assert (padded_td["b", "c"][~padded_td["masks", "b", "c"]] == 0).all()
+        else:
+            assert "masks" not in padded_td.keys()

     @pytest.mark.parametrize("device", get_available_devices())
     def test_permute(self, device):
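
For context, here is a minimal sketch of the behavior the new assertions pin down. It assumes pad_sequence and TensorDict are importable from the top-level tensordict package (the test module's actual imports may differ), and the variable names are illustrative only:

    import torch
    from tensordict import TensorDict, pad_sequence

    # Two tensordicts of unequal length along dim 0 (2 vs. 4 steps).
    list_td = [
        TensorDict({"a": torch.ones(2, 8, 8)}, batch_size=[2]),
        TensorDict({"a": torch.full((4, 8, 8), 2.0)}, batch_size=[4]),
    ]

    # Stack and pad along pad_dim=0, requesting boolean masks.
    padded = pad_sequence(list_td, pad_dim=0, return_mask=True)

    assert padded.shape == torch.Size([2, 4])  # 2 tensordicts, padded to 4 steps
    assert padded["a"].shape == torch.Size([2, 4, 8, 8])
    assert padded["masks", "a"].ndim == 0 + 2  # pad_dim + 2: stack dim plus padded dim
    assert (padded["a"][padded["masks", "a"]] != 0).all()  # mask keeps the real values
    assert (padded["a"][~padded["masks", "a"]] == 0).all()  # the rest is zero padding

The ndim == pad_dim + 2 check reflects that each mask covers the stacking dimension plus every batch dimension up to and including the padded one, so it can be used directly to index the padded leaves, as in the last two assertions above.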
