diff --git a/test/test_utils.py b/test/test_utils.py
index 21f9bf529..d8a2518b0 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -211,7 +211,7 @@ def test_remove_duplicates_1dim(key, dim):
     assert (output_tensordict == expected_output).all()
 
 
-@pytest.mark.parametrize("dim", (0, 1, -1, -2))
+@pytest.mark.parametrize("dim", (0, 1, 2, -1, -2, -3))
 def test_remove_duplicates_2dims(dim):
     key = "tensor1"
     input_tensordict = TensorDict(
@@ -222,26 +222,32 @@ def test_remove_duplicates_2dims(dim):
         batch_size=[4, 4],
     )
 
-    output_tensordict = remove_duplicates(input_tensordict, key, dim)
+    if dim in (2, -3):
+        with pytest.raises(
+            ValueError,
+            match=f"The specified dimension '{dim}' is invalid for a TensorDict with batch size .*.",
+        ):
+            remove_duplicates(input_tensordict, key, dim)
 
-    if dim in (0, -2):
-        expected_output = TensorDict(
-            {
-                "tensor1": torch.ones(1, 4),
-                "tensor2": torch.ones(1, 4),
-            },
-            batch_size=[1, 4],
-        )
     else:
-        expected_output = TensorDict(
-            {
-                "tensor1": torch.ones(4, 1),
-                "tensor2": torch.ones(4, 1),
-            },
-            batch_size=[4, 1],
-        )
-
-    assert (output_tensordict == expected_output).all()
+        output_tensordict = remove_duplicates(input_tensordict, key, dim)
+        if dim in (0, -2):
+            expected_output = TensorDict(
+                {
+                    "tensor1": torch.ones(1, 4),
+                    "tensor2": torch.ones(1, 4),
+                },
+                batch_size=[1, 4],
+            )
+        else:
+            expected_output = TensorDict(
+                {
+                    "tensor1": torch.ones(4, 1),
+                    "tensor2": torch.ones(4, 1),
+                },
+                batch_size=[4, 1],
+            )
+        assert (output_tensordict == expected_output).all()
 
 
 if __name__ == "__main__":