Skip to content

Commit

Permalink
Support loading base64 images in pipelines (#25633)
Browse files Browse the repository at this point in the history
* support loading base64 images

* add test

* mention in docs

* remove the logging

* sort imports

* update error message

* Update tests/utils/test_image_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* restructure to catch base64 exception

* doesn't like the newline

* download files

* format

* optimize imports

* guess it needs a space?

* support loading base64 images

* add test

* remove the logging

* sort imports

* restructure to catch base64 exception

* doesn't like the newline

* download files

* optimize imports

* guess it needs a space?

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
InventivetalentDev and amyeroberts authored Aug 29, 2023
1 parent ce2d4bc commit dbc16f4
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/pipeline_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ page.

Using a [`pipeline`] for vision tasks is practically identical.

Specify your task and pass your image to the classifier. The image can be a link or a local path to the image. For example, what species of cat is shown below?
Specify your task and pass your image to the classifier. The image can be a link, a local path or a base64-encoded image. For example, what species of cat is shown below?

![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)

Expand Down
18 changes: 14 additions & 4 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -298,14 +300,22 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
if image.startswith("data:image/"):
image = image.split(",")[1]

# Try to load as base64
try:
b64 = base64.b64decode(image, validate=True)
image = PIL.Image.open(BytesIO(b64))
except Exception as e:
raise ValueError(
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
Expand Down
37 changes: 37 additions & 0 deletions tests/utils/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import datasets
import numpy as np
import pytest
from huggingface_hub.file_download import http_get
from requests import ReadTimeout

from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
Expand Down Expand Up @@ -500,6 +503,40 @@ def test_load_img_local(self):
(480, 640, 3),
)

def test_load_img_base64_prefix(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get(
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt", f
)

with open(tmp_file, encoding="utf-8") as b64:
img = load_image(b64.read())
img_arr = np.array(img)

finally:
os.remove(tmp_file)

self.assertEqual(img_arr.shape, (64, 32, 3))

def test_load_img_base64(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get(
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt", f
)

with open(tmp_file, encoding="utf-8") as b64:
img = load_image(b64.read())
img_arr = np.array(img)

finally:
os.remove(tmp_file)

self.assertEqual(img_arr.shape, (64, 32, 3))

def test_load_img_rgba(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")

Expand Down

0 comments on commit dbc16f4

Please sign in to comment.