Skip to content

Commit

Permalink
more checks 2dims
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Feb 4, 2024
1 parent 7fa567e commit da1a494
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_remove_duplicates_1dim(key, dim):
assert (output_tensordict == expected_output).all()


@pytest.mark.parametrize("dim", (0, 1, -1, -2))
@pytest.mark.parametrize("dim", (0, 1, 2, -1, -2, -3))
def test_remove_duplicates_2dims(dim):
key = "tensor1"
input_tensordict = TensorDict(
Expand All @@ -222,26 +222,32 @@ def test_remove_duplicates_2dims(dim):
batch_size=[4, 4],
)

output_tensordict = remove_duplicates(input_tensordict, key, dim)
if dim in (2, -3):
with pytest.raises(
ValueError,
match=f"The specified dimension '{dim}' is invalid for a TensorDict with batch size .*.",
):
remove_duplicates(input_tensordict, key, dim)

if dim in (0, -2):
expected_output = TensorDict(
{
"tensor1": torch.ones(1, 4),
"tensor2": torch.ones(1, 4),
},
batch_size=[1, 4],
)
else:
expected_output = TensorDict(
{
"tensor1": torch.ones(4, 1),
"tensor2": torch.ones(4, 1),
},
batch_size=[4, 1],
)

assert (output_tensordict == expected_output).all()
output_tensordict = remove_duplicates(input_tensordict, key, dim)
if dim in (0, -2):
expected_output = TensorDict(
{
"tensor1": torch.ones(1, 4),
"tensor2": torch.ones(1, 4),
},
batch_size=[1, 4],
)
else:
expected_output = TensorDict(
{
"tensor1": torch.ones(4, 1),
"tensor2": torch.ones(4, 1),
},
batch_size=[4, 1],
)
assert (output_tensordict == expected_output).all()


if __name__ == "__main__":
Expand Down

0 comments on commit da1a494

Please sign in to comment.