From da1a494ce4aa0776c5d29df2e82b15c0e6da138b Mon Sep 17 00:00:00 2001 From: albert bou Date: Sun, 4 Feb 2024 17:08:55 +0100 Subject: [PATCH] more checks 2dims --- test/test_utils.py | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 21f9bf529..d8a2518b0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -211,7 +211,7 @@ def test_remove_duplicates_1dim(key, dim): assert (output_tensordict == expected_output).all() -@pytest.mark.parametrize("dim", (0, 1, -1, -2)) +@pytest.mark.parametrize("dim", (0, 1, 2, -1, -2, -3)) def test_remove_duplicates_2dims(dim): key = "tensor1" input_tensordict = TensorDict( @@ -222,26 +222,32 @@ def test_remove_duplicates_2dims(dim): batch_size=[4, 4], ) - output_tensordict = remove_duplicates(input_tensordict, key, dim) + if dim in (2, -3): + with pytest.raises( + ValueError, + match=f"The specified dimension '{dim}' is invalid for a TensorDict with batch size .*.", + ): + remove_duplicates(input_tensordict, key, dim) - if dim in (0, -2): - expected_output = TensorDict( - { - "tensor1": torch.ones(1, 4), - "tensor2": torch.ones(1, 4), - }, - batch_size=[1, 4], - ) else: - expected_output = TensorDict( - { - "tensor1": torch.ones(4, 1), - "tensor2": torch.ones(4, 1), - }, - batch_size=[4, 1], - ) - - assert (output_tensordict == expected_output).all() + output_tensordict = remove_duplicates(input_tensordict, key, dim) + if dim in (0, -2): + expected_output = TensorDict( + { + "tensor1": torch.ones(1, 4), + "tensor2": torch.ones(1, 4), + }, + batch_size=[1, 4], + ) + else: + expected_output = TensorDict( + { + "tensor1": torch.ones(4, 1), + "tensor2": torch.ones(4, 1), + }, + batch_size=[4, 1], + ) + assert (output_tensordict == expected_output).all() if __name__ == "__main__":