Put load_image function in image_utils.py & fix image rotation issue #14062

Merged · 8 commits · Nov 3, 2021

Changes from all commits
37 changes: 37 additions & 0 deletions src/transformers/image_utils.py
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Union

import numpy as np
import PIL.Image
import PIL.ImageOps

import requests

from .file_utils import _is_torch, is_torch_available

@@ -35,6 +39,39 @@ def is_torch_tensor(obj):
return _is_torch(obj) if is_torch_available() else False


def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
"""
Loads :obj:`image` to a PIL Image.

Args:
image (:obj:`str` or :obj:`PIL.Image.Image`):
The image to convert to the PIL Image format.

Returns:
:obj:`PIL.Image.Image`: A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
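For orientation, a minimal usage sketch of the new shared helper (the path and URL below are placeholders, not files shipped with the repository):

```python
from PIL import Image

from transformers.image_utils import load_image

# A local path (placeholder), an http(s) URL (placeholder), or an already
# opened PIL image are all accepted.
img = load_image("path/to/photo.jpg")
img = load_image("https://example.com/photo.jpg")
img = load_image(Image.open("path/to/photo.jpg"))

# In every case the result is a PIL.Image.Image in RGB mode with the EXIF
# orientation tag already applied, so photos taken with a rotated camera
# come back in their display orientation.
print(img.mode)  # "RGB"
```

Strings that merely look like URLs (for example a local file named http_huggingface_co.png) are only treated as remote when they actually start with http:// or https://, matching the comment in the implementation.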
31 changes: 4 additions & 27 deletions src/transformers/pipelines/image_classification.py
@@ -1,8 +1,5 @@
import os
from typing import List, Union

import requests

from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -11,6 +8,8 @@
if is_vision_available():
from PIL import Image

from ..image_utils import load_image

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING

@@ -39,35 +38,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)

@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image

def _sanitize_parameters(self, top_k=None):
postprocess_params = {}
if top_k is not None:
postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params

def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
Assign labels to the image(s) passed as inputs.

@@ -99,7 +76,7 @@ def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
return super().__call__(images, **kwargs)

def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
return model_inputs

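As a usage sketch of the refactored pipeline (the checkpoint name is only an illustrative choice, and the path is a placeholder), any supported input form now flows through the shared helper before feature extraction:

```python
from transformers import pipeline

# Illustrative checkpoint; any image-classification model behaves the same.
classifier = pipeline("image-classification", model="google/vit-base-patch16-224")

# URLs, local paths, and PIL images are all resolved by image_utils.load_image
# inside preprocess(), so an EXIF-rotated photo is classified in its display
# orientation rather than its raw sensor orientation.
predictions = classifier("path/to/rotated_photo.jpg", top_k=3)
for prediction in predictions:
    print(prediction["label"], round(prediction["score"], 3))
```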
29 changes: 3 additions & 26 deletions src/transformers/pipelines/image_segmentation.py
@@ -1,12 +1,9 @@
import base64
import io
import os
from typing import Any, Dict, List, Union

import numpy as np

import requests

from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -15,6 +12,8 @@
if is_vision_available():
from PIL import Image

from ..image_utils import load_image

if is_torch_available():
import torch

@@ -49,28 +48,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)

@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image

def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
if "threshold" in kwargs:
@@ -118,7 +95,7 @@ def get_inference_context(self):
return torch.no_grad

def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size
30 changes: 3 additions & 27 deletions src/transformers/pipelines/object_detection.py
@@ -1,15 +1,13 @@
import os
from typing import Any, Dict, List, Union

import requests

from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_vision_available():
from PIL import Image
from ..image_utils import load_image


if is_torch_available():
import torch
@@ -45,28 +43,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)

@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image

def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
if "threshold" in kwargs:
@@ -105,7 +81,7 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
return super().__call__(*args, **kwargs)

def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size
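The object-detection pipeline picks up the same behaviour; a short sketch under the same assumptions (placeholder checkpoint and path):

```python
from transformers import pipeline

# Illustrative checkpoint; any object-detection model would do.
detector = pipeline("object-detection", model="facebook/detr-resnet-50")

# load_image corrects the orientation before the feature extractor runs,
# so bounding boxes are reported on the upright image.
detections = detector("path/to/street_scene.jpg")
for detection in detections:
    print(detection["label"], detection["score"], detection["box"])
```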
67 changes: 67 additions & 0 deletions tests/test_image_utils.py
@@ -15,6 +15,7 @@

import unittest

import datasets
import numpy as np

from transformers import is_torch_available, is_vision_available
@@ -28,6 +29,7 @@
import PIL.Image

from transformers import ImageFeatureExtractionMixin
from transformers.image_utils import load_image


def get_random_image(height, width):
@@ -367,3 +369,68 @@ def test_center_crop_tensor(self):
# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))


@require_vision
class LoadImageTester(unittest.TestCase):
def test_load_img_local(self):
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
img_arr = np.array(img)

self.assertEqual(
img_arr.shape,
(480, 640, 3),
)

def test_load_img_rgba(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")

img = load_image(dataset[0]["file"]) # img with mode RGBA
img_arr = np.array(img)

self.assertEqual(
img_arr.shape,
(512, 512, 3),
)

def test_load_img_la(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")

img = load_image(dataset[1]["file"]) # img with mode LA
img_arr = np.array(img)

self.assertEqual(
img_arr.shape,
(512, 768, 3),
)

def test_load_img_l(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")

img = load_image(dataset[2]["file"]) # img with mode L
img_arr = np.array(img)

self.assertEqual(
img_arr.shape,
(381, 225, 3),
)

def test_load_img_exif_transpose(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img_file = dataset[3]["file"]

img_without_exif_transpose = PIL.Image.open(img_file)
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)

self.assertEqual(
img_arr_without_exif_transpose.shape,
(333, 500, 3),
)

img_with_exif_transpose = load_image(img_file)
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)

self.assertEqual(
img_arr_with_exif_transpose.shape,
(500, 333, 3),
)
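For anyone who wants to reproduce the rotation problem that test_load_img_exif_transpose covers without downloading the hf-internal-testing fixtures, here is a minimal, self-contained sketch using a synthetic image and only PIL:

```python
import os
import tempfile

from PIL import Image, ImageOps

# Build a small landscape image and tag it with EXIF orientation 6
# ("rotate 90 degrees clockwise to display"), roughly what a phone camera writes.
img = Image.new("RGB", (640, 480), color="red")
exif = img.getexif()
exif[0x0112] = 6  # EXIF Orientation tag

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rotated.jpg")
    img.save(path, exif=exif.tobytes())

    # Opening the file directly keeps the raw pixel layout ...
    raw = Image.open(path)
    print(raw.size)  # (640, 480)

    # ... while exif_transpose (now applied by load_image) rotates it upright.
    upright = ImageOps.exif_transpose(raw)
    print(upright.size)  # (480, 640)
```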