Skip to content

Commit

Permalink
fix: preserve quality and optimize transfer of prompt images (#570)
Browse files Browse the repository at this point in the history
* fix: preserve quality and optimize transfer of prompt images

* Move numpy-images to their own test case.

Change-Id: Ie6b02c7647487c1df9d4e70e9b8eed70dc8b8fe3

* Format with black

Change-Id: I04550a89eed9bb21c0a8f6f9b6ab76b8b0f41270

---------

Co-authored-by: Mark Daoust <markdaoust@google.com>
  • Loading branch information
PicardParis and MarkDaoust authored Sep 24, 2024
1 parent 8f7f5cb commit 6c8dad1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
68 changes: 31 additions & 37 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io
import inspect
import mimetypes
import pathlib
import typing
from typing import Any, Callable, Union
from typing_extensions import TypedDict
Expand All @@ -30,15 +31,15 @@

if typing.TYPE_CHECKING:
import PIL.Image
import PIL.PngImagePlugin
import PIL.ImageFile
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
else:
IMAGE_TYPES = ()
try:
import PIL.Image
import PIL.PngImagePlugin
import PIL.ImageFile

IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
except ImportError:
Expand Down Expand Up @@ -72,46 +73,39 @@
]


def pil_to_blob(img):
# When you load an image with PIL you get a subclass of PIL.Image
# The subclass knows what file type it was loaded from it has a `.format` class attribute
# and the `get_format_mimetype` method. Convert these back to the same file type.
#
# The base image class doesn't know its file type, it just knows its mode.
# RGBA converts to PNG easily, P[allet] converts to GIF, RGB to GIF.
# But for anything else I'm not going to bother mapping it out (for now) let's just convert to RGB and send it.
#
# References:
# - file formats: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html
# - image modes: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes

bytesio = io.BytesIO()

get_mime = getattr(img, "get_format_mimetype", None)
if get_mime is not None:
# If the image is created from a file, convert back to the same file type.
img.save(bytesio, format=img.format)
mime_type = img.get_format_mimetype()
elif img.mode == "RGBA":
img.save(bytesio, format="PNG")
mime_type = "image/png"
elif img.mode == "P":
img.save(bytesio, format="GIF")
mime_type = "image/gif"
else:
if img.mode != "RGB":
img = img.convert("RGB")
img.save(bytesio, format="JPEG")
mime_type = "image/jpeg"
bytesio.seek(0)
data = bytesio.read()
return protos.Blob(mime_type=mime_type, data=data)
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
# If the image is a local file, return a file-based blob without any modification.
# Otherwise, return a lossless WebP blob (same quality with optimized size).
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
return None
filename = str(image.filename)
if not pathlib.Path(filename).is_file():
return None

mime_type = image.get_format_mimetype()
image_bytes = pathlib.Path(filename).read_bytes()

return protos.Blob(mime_type=mime_type, data=image_bytes)

def webp_blob(image: PIL.Image.Image) -> protos.Blob:
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
image_io = io.BytesIO()
image.save(image_io, format="webp", lossless=True)
image_io.seek(0)

mime_type = "image/webp"
image_bytes = image_io.read()

return protos.Blob(mime_type=mime_type, data=image_bytes)

return file_blob(image) or webp_blob(image)


def image_to_blob(image) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return pil_to_blob(image)
return _pil_to_blob(image)

if IPython is not None:
if isinstance(image, IPython.display.Image):
Expand Down
15 changes: 12 additions & 3 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,20 @@ class HasEnum:


class UnitTests(parameterized.TestCase):

@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_PNG_PATH)],
["RGBA", PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))],
["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))],
["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")],
)
def test_numpy_to_blob(self, image):
blob = content_types.image_to_blob(image)
self.assertIsInstance(blob, protos.Blob)
self.assertEqual(blob.mime_type, "image/webp")
self.assertStartsWith(blob.data, b"RIFF \x00\x00\x00WEBPVP8L")

@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_PNG_PATH)],
["IPython", IPython.display.Image(filename=TEST_PNG_PATH)],
)
def test_png_to_blob(self, image):
Expand All @@ -96,7 +107,6 @@ def test_png_to_blob(self, image):

@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_JPG_PATH)],
["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))],
["IPython", IPython.display.Image(filename=TEST_JPG_PATH)],
)
def test_jpg_to_blob(self, image):
Expand All @@ -107,7 +117,6 @@ def test_jpg_to_blob(self, image):

@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_GIF_PATH)],
["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")],
["IPython", IPython.display.Image(filename=TEST_GIF_PATH)],
)
def test_gif_to_blob(self, image):
Expand Down

0 comments on commit 6c8dad1

Please sign in to comment.