diff --git a/test/test_image.py b/test/test_image.py index d489b10af7c..4d14af638a0 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -42,6 +42,19 @@ IS_WINDOWS = sys.platform in ("win32", "cygwin") IS_MACOS = sys.platform == "darwin" PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) +WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "") + +# Hacky way of figuring out whether we compiled with libavif/libheif (those are +# currenlty disabled by default) +try: + _decode_avif(torch.arange(10, dtype=torch.uint8)) +except Exception as e: + DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e) + +try: + _decode_heic(torch.arange(10, dtype=torch.uint8)) +except Exception as e: + DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e) def _get_safe_image_name(name): @@ -149,17 +162,6 @@ def test_invalid_exif(tmpdir, size): torch.testing.assert_close(expected, output) -def test_decode_jpeg_errors(): - with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - - with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): - decode_jpeg(torch.empty((100,), dtype=torch.float16)) - - with pytest.raises(RuntimeError, match="Not a JPEG file"): - decode_jpeg(torch.empty((100), dtype=torch.uint8)) - - def test_decode_bad_huffman_images(): # sanity check: make sure we can decode the bad Huffman encoding bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")) @@ -235,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun): def test_decode_png_errors(): - with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_png(torch.empty((), dtype=torch.uint8)) - with pytest.raises(RuntimeError, match="Content is not png"): - decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) with pytest.raises(RuntimeError, match="Out of bound read in decode_png"): decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png"))) with pytest.raises(RuntimeError, match="Content is too small for png"): @@ -864,8 +862,20 @@ def test_decode_gif(tmpdir, name, scripted): torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0) -@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp)) -def test_decode_gif_webp_errors(decode_fun): +decode_fun_and_match = [ + (decode_png, "Content is not png"), + (decode_jpeg, "Not a JPEG file"), + (decode_gif, re.escape("DGifOpenFileName() failed - 103")), + (decode_webp, "WebPGetFeatures failed."), +] +if DECODE_AVIF_ENABLED: + decode_fun_and_match.append((_decode_avif, "BMFF parsing failed")) +if DECODE_HEIC_ENABLED: + decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box")) + + +@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match) +def test_decode_bad_encoded_data(decode_fun, match): encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8) with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"): decode_fun(encoded_data[None]) @@ -873,11 +883,7 @@ def test_decode_gif_webp_errors(decode_fun): decode_fun(encoded_data.float()) with pytest.raises(RuntimeError, match="Input tensor must be contiguous"): decode_fun(encoded_data[::2]) - if decode_fun is decode_gif: - expected_match = re.escape("DGifOpenFileName() failed - 103") - elif decode_fun is decode_webp: - expected_match = "WebPGetFeatures failed." - with pytest.raises(RuntimeError, match=expected_match): + with pytest.raises(RuntimeError, match=match): decode_fun(encoded_data) @@ -890,21 +896,27 @@ def test_decode_webp(decode_fun, scripted): img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -# This test is skipped because it requires webp images that we're not including -# within the repo. The test images were downloaded from the different pages of -# https://developers.google.com/speed/webp/gallery -# Note that converting an RGBA image to RGB leads to bad results because the -# transparent pixels aren't necessarily set to "black" or "white", they can be -# random stuff. This is consistent with PIL results. -@pytest.mark.skip(reason="Need to download test images first") +# This test is skipped by default because it requires webp images that we're not +# including within the repo. The test images were downloaded manually from the +# different pages of https://developers.google.com/speed/webp/gallery +@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set") @pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( - "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None)) + "mode, pil_mode", + ( + # Note that converting an RGBA image to RGB leads to bad results because the + # transparent pixels aren't necessarily set to "black" or "white", they can be + # random stuff. This is consistent with PIL results. + (ImageReadMode.RGB, "RGB"), + (ImageReadMode.RGB_ALPHA, "RGBA"), + (ImageReadMode.UNCHANGED, None), + ), ) -@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp")) +@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name) def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename): encoded_bytes = read_file(filename) if scripted: @@ -915,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename) pil_img = Image.open(filename).convert(pil_mode) from_pil = F.pil_to_tensor(pil_img) assert_equal(img, from_pil) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -@pytest.mark.xfail(reason="AVIF support not enabled yet.") +@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.") @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) def test_decode_avif(decode_fun, scripted): @@ -927,12 +940,20 @@ def test_decode_avif(decode_fun, scripted): img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib -@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.") # Note: decode_image fails because some of these files have a (valid) signature # we don't recognize. We should probably use libmagic.... -@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic)) +decode_funs = [] +if DECODE_AVIF_ENABLED: + decode_funs.append(_decode_avif) +if DECODE_HEIC_ENABLED: + decode_funs.append(_decode_heic) + + +@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.") +@pytest.mark.parametrize("decode_fun", decode_funs) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( "mode, pil_mode", @@ -945,7 +966,7 @@ def test_decode_avif(decode_fun, scripted): @pytest.mark.parametrize( "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name ) -def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): +def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename): if "reversed_dimg_order" in str(filename): # Pillow properly decodes this one, but we don't (order of parts of the # image is wrong). This is due to a bug that was recently fixed in @@ -996,21 +1017,21 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) g = make_grid([img, from_pil]) F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) - is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" - if mode == ImageReadMode.RGB and not is__decode_heic: + is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" + if mode == ImageReadMode.RGB and not is_decode_heic: # We don't compare torchvision's AVIF against PIL for RGB because # results look pretty different on RGBA images (other images are fine). # The result on torchvision basically just plainly ignores the alpha # channel, resuting in transparent pixels looking dark. PIL seems to be # using a sort of k-nn thing (Take a look at the resuting images) return - if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic: + if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic: return torch.testing.assert_close(img, from_pil, rtol=0, atol=3) -@pytest.mark.xfail(reason="HEIC support not enabled yet.") +@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.") @pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image)) @pytest.mark.parametrize("scripted", (False, True)) def test_decode_heic(decode_fun, scripted): @@ -1020,6 +1041,7 @@ def test_decode_heic(decode_fun, scripted): img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) + img += 123 # make sure image buffer wasn't freed by underlying decoding lib if __name__ == "__main__": diff --git a/torchvision/csrc/io/image/common.cpp b/torchvision/csrc/io/image/common.cpp new file mode 100644 index 00000000000..16b7ac2f91e --- /dev/null +++ b/torchvision/csrc/io/image/common.cpp @@ -0,0 +1,43 @@ + +#include "common.h" +#include + +namespace vision { +namespace image { + +void validate_encoded_data(const torch::Tensor& encoded_data) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1 && encoded_data.numel() > 0, + "Input tensor must be 1-dimensional and non-empty, got ", + encoded_data.dim(), + " dims and ", + encoded_data.numel(), + " numels."); +} + +bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + ImageReadMode mode, + bool has_alpha) { + // Return true if the calling decoding function should return a 3D RGB tensor, + // and false if it should return a 4D RGBA tensor. + // This function ignores the requested "grayscale" modes and treats it as + // "unchanged", so it should only used on decoders who don't support grayscale + // outputs. + + if (mode == IMAGE_READ_MODE_RGB) { + return true; + } + if (mode == IMAGE_READ_MODE_RGB_ALPHA) { + return false; + } + // From here we assume mode is "unchanged", even for grayscale ones. + return !has_alpha; +} + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image_read_mode.h b/torchvision/csrc/io/image/common.h similarity index 65% rename from torchvision/csrc/io/image/image_read_mode.h rename to torchvision/csrc/io/image/common.h index 84425265c34..d81acfda7d4 100644 --- a/torchvision/csrc/io/image/image_read_mode.h +++ b/torchvision/csrc/io/image/common.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace vision { namespace image { @@ -13,5 +14,11 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; const ImageReadMode IMAGE_READ_MODE_RGB = 3; const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; +void validate_encoded_data(const torch::Tensor& encoded_data); + +bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + ImageReadMode mode, + bool has_alpha); + } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp index 5752f04a448..3cb326e2f11 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_avif.cpp @@ -1,4 +1,5 @@ #include "decode_avif.h" +#include "../common.h" #if AVIF_FOUND #include "avif/avif.h" @@ -33,16 +34,7 @@ torch::Tensor decode_avif( // Refer there for more detail about what each function does, and which // structure/data is available after which call. - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); DecoderPtr decoder(avifDecoderCreate()); TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); @@ -60,6 +52,7 @@ torch::Tensor decode_avif( result == AVIF_RESULT_OK, "avifDecoderParse failed: ", avifResultToString(result)); + printf("avif num images = %d\n", decoder->imageCount); TORCH_CHECK( decoder->imageCount == 1, "Avif file contains more than one image"); @@ -78,18 +71,9 @@ torch::Tensor decode_avif( auto use_uint8 = (decoder->image->depth <= 8); rgb.depth = use_uint8 ? 8 : 16; - if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && - mode != IMAGE_READ_MODE_RGB_ALPHA) { - // Other modes aren't supported, but we don't error or even warn because we - // have generic entry points like decode_image which may support all modes, - // it just depends on the underlying decoder. - mode = IMAGE_READ_MODE_UNCHANGED; - } - - // If return_rgb is false it means we return rgba - nothing else. auto return_rgb = - (mode == IMAGE_READ_MODE_RGB || - (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent)); + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, decoder->alphaPresent); auto num_channels = return_rgb ? 3 : 4; rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; diff --git a/torchvision/csrc/io/image/cpu/decode_avif.h b/torchvision/csrc/io/image/cpu/decode_avif.h index 0510c2104e5..7feee1adfcb 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.h +++ b/torchvision/csrc/io/image/cpu/decode_avif.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_gif.cpp b/torchvision/csrc/io/image/cpu/decode_gif.cpp index 183d42e86a4..f26d37950e3 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_gif.cpp @@ -1,5 +1,6 @@ #include "decode_gif.h" #include +#include "../common.h" #include "giflib/gif_lib.h" namespace vision { @@ -34,16 +35,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // Refer over there for more details on the libgif API, API ref, and a // detailed description of the GIF format. - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); int error = D_GIF_SUCCEEDED; diff --git a/torchvision/csrc/io/image/cpu/decode_heic.cpp b/torchvision/csrc/io/image/cpu/decode_heic.cpp index 148d6043f10..e245c25f9d7 100644 --- a/torchvision/csrc/io/image/cpu/decode_heic.cpp +++ b/torchvision/csrc/io/image/cpu/decode_heic.cpp @@ -1,4 +1,5 @@ #include "decode_heic.h" +#include "../common.h" #if HEIC_FOUND #include "libheif/heif_cxx.h" @@ -19,26 +20,8 @@ torch::Tensor decode_heic( torch::Tensor decode_heic( const torch::Tensor& encoded_data, ImageReadMode mode) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); - - if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && - mode != IMAGE_READ_MODE_RGB_ALPHA) { - // Other modes aren't supported, but we don't error or even warn because we - // have generic entry points like decode_image which may support all modes, - // it just depends on the underlying decoder. - mode = IMAGE_READ_MODE_UNCHANGED; - } + validate_encoded_data(encoded_data); - // If return_rgb is false it means we return rgba - nothing else. auto return_rgb = true; int height = 0; @@ -82,8 +65,8 @@ torch::Tensor decode_heic( bit_depth = handle.get_luma_bits_per_pixel(); return_rgb = - (mode == IMAGE_READ_MODE_RGB || - (mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel())); + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, handle.has_alpha_channel()); height = handle.get_height(); width = handle.get_width(); diff --git a/torchvision/csrc/io/image/cpu/decode_heic.h b/torchvision/csrc/io/image/cpu/decode_heic.h index 4a23e4c1431..10b414f554d 100644 --- a/torchvision/csrc/io/image/cpu/decode_heic.h +++ b/torchvision/csrc/io/image/cpu/decode_heic.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_image.h b/torchvision/csrc/io/image/cpu/decode_image.h index f0e66d397ac..f66d47eccd4 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.h +++ b/torchvision/csrc/io/image/cpu/decode_image.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index ec5953e4106..052b98e1be9 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -1,4 +1,5 @@ #include "decode_jpeg.h" +#include "../common.h" #include "common_jpeg.h" #include "exif.h" @@ -134,12 +135,8 @@ torch::Tensor decode_jpeg( bool apply_exif_orientation) { C10_LOG_API_USAGE_ONCE( "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + + validate_encoded_data(data); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.h b/torchvision/csrc/io/image/cpu/decode_jpeg.h index e0c9a24c846..7412a46d2ea 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index ac14ae934a4..ede14c1e94a 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -1,4 +1,5 @@ #include "decode_png.h" +#include "../common.h" #include "common_png.h" #include "exif.h" @@ -27,12 +28,8 @@ torch::Tensor decode_png( ImageReadMode mode, bool apply_exif_orientation) { C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + + validate_encoded_data(data); auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h index 0866711e987..faaffa7ae49 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.h +++ b/torchvision/csrc/io/image/cpu/decode_png.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index bf115c23c41..b202473c039 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -1,4 +1,5 @@ #include "decode_webp.h" +#include "../common.h" #if WEBP_FOUND #include "webp/decode.h" @@ -19,16 +20,7 @@ torch::Tensor decode_webp( torch::Tensor decode_webp( const torch::Tensor& encoded_data, ImageReadMode mode) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); + validate_encoded_data(encoded_data); auto encoded_data_p = encoded_data.data_ptr(); auto encoded_data_size = encoded_data.numel(); @@ -40,18 +32,9 @@ torch::Tensor decode_webp( TORCH_CHECK( !features.has_animation, "Animated webp files are not supported."); - if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && - mode != IMAGE_READ_MODE_RGB_ALPHA) { - // Other modes aren't supported, but we don't error or even warn because we - // have generic entry points like decode_image which may support all modes, - // it just depends on the underlying decoder. - mode = IMAGE_READ_MODE_UNCHANGED; - } - - // If return_rgb is false it means we return rgba - nothing else. auto return_rgb = - (mode == IMAGE_READ_MODE_RGB || - (mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha)); + should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( + mode, features.has_alpha); auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA; auto num_channels = return_rgb ? 3 : 4; diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index 5632ea56ff9..d5c81547c42 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" namespace vision { namespace image { diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 6314ececef1..2079ca5f919 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -139,7 +139,7 @@ std::vector decode_jpegs_cuda( } CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, + : original_device{torch::kCUDA, c10::cuda::current_device()}, target_device{target_device}, stream{ target_device.has_index() diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 2458a103a3a..6f72d9e35b2 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,7 +1,7 @@ #pragma once #include #include -#include "../image_read_mode.h" +#include "../common.h" #if NVJPEG_FOUND #include diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 3fdf715b00f..8c3ad8f9a9d 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -1,7 +1,7 @@ #pragma once #include -#include "../image_read_mode.h" +#include "../common.h" #include "decode_jpegs_cuda.h" #include "encode_jpegs_cuda.h"