diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 5ab35fc873b..aa11982a2f3 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -116,11 +116,9 @@ class TestSmoke:
             (transforms.RandAugment(), auto_augment_adapter),
             (transforms.TrivialAugmentWide(), auto_augment_adapter),
             (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
-            (transforms.Grayscale(), None),
             (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
             (transforms.RandomAutocontrast(p=1.0), None),
             (transforms.RandomEqualize(p=1.0), None),
-            (transforms.RandomGrayscale(p=1.0), None),
             (transforms.RandomInvert(p=1.0), None),
             (transforms.RandomChannelPermutation(), None),
             (transforms.RandomPhotometricDistort(p=1.0), None),
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index efeb673059f..26ef3121809 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -122,17 +122,6 @@ def __init__(
             (torch.float32, torch.float64),
         ]
     ],
-    ConsistencyConfig(
-        v2_transforms.Grayscale,
-        legacy_transforms.Grayscale,
-        [
-            ArgsKwargs(num_output_channels=1),
-            ArgsKwargs(num_output_channels=3),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
     ConsistencyConfig(
         v2_transforms.ToPILImage,
         legacy_transforms.ToPILImage,
@@ -217,17 +206,6 @@ def __init__(
         ],
         closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
     ),
-    ConsistencyConfig(
-        v2_transforms.RandomGrayscale,
-        legacy_transforms.RandomGrayscale,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
     ConsistencyConfig(
         v2_transforms.PILToTensor,
         legacy_transforms.PILToTensor,
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 59d30d482e2..55423d359dd 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -3945,3 +3945,58 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue):
 
         mae = (actual.float() - expected.float()).abs().mean()
         assert mae < 2
+
+
+class TestRgbToGrayscale:
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image(self, dtype, device):
+        check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_functional(self, make_input):
+        check_functional(F.rgb_to_grayscale, make_input())
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.rgb_to_grayscale_image, torch.Tensor),
+            (F._rgb_to_grayscale_image_pil, PIL.Image.Image),
+            (F.rgb_to_grayscale_image, tv_tensors.Image),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    def test_transform(self, transform, make_input):
+        check_transform(transform, make_input())
+
+    @pytest.mark.parametrize("num_output_channels", [1, 3])
+    @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
+    def test_image_correctness(self, num_output_channels, fn):
+        image = make_image(dtype=torch.uint8, device="cpu")
+
+        actual = fn(image, num_output_channels=num_output_channels)
+        expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
+
+        assert_equal(actual, expected, rtol=0, atol=1)
+
+    @pytest.mark.parametrize("num_input_channels", [1, 3])
+    def test_random_transform_correctness(self, num_input_channels):
+        image = make_image(
+            color_space={
+                1: "GRAY",
+                3: "RGB",
+            }[num_input_channels],
+            dtype=torch.uint8,
+            device="cpu",
+        )
+
+        transform = transforms.RandomGrayscale(p=1)
+
+        actual = transform(image)
+        expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_input_channels))
+
+        assert_equal(actual, expected, rtol=0, atol=1)