diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py index 8595c43303..386a974cf4 100644 --- a/tests/link/pytorch/test_sort.py +++ b/tests/link/pytorch/test_sort.py @@ -7,9 +7,17 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py -@pytest.mark.xfail(reason="Reshape not implemented") -@pytest.mark.parametrize("axis", [0, 1, None]) @pytest.mark.parametrize("func", (sort, argsort)) +@pytest.mark.parametrize( + "axis", + [ + pytest.param(0), + pytest.param(1), + pytest.param( + None, marks=pytest.mark.xfail(reason="Reshape Op not implemented") + ), + ], +) def test_sort(func, axis): x = matrix("x", shape=(2, 2), dtype="float64") out = func(x, axis=axis)