Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Created add_image_text_converter and unit tests #328

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/prompt_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter
from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter

from pyrit.prompt_converter.add_image_text_converter import AddImageTextConverter
from pyrit.prompt_converter.add_text_image_converter import AddTextImageConverter
from pyrit.prompt_converter.ascii_art_converter import AsciiArtConverter
from pyrit.prompt_converter.atbash_converter import AtbashConverter
Expand Down Expand Up @@ -34,6 +35,7 @@


__all__ = [
"AddImageTextConverter",
"AddTextImageConverter",
"AsciiArtConverter",
"AtbashConverter",
Expand Down
144 changes: 144 additions & 0 deletions pyrit/prompt_converter/add_image_text_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import base64
import string
from typing import Optional

from PIL import Image, ImageDraw, ImageFont
import textwrap
from io import BytesIO

from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.models.prompt_request_piece import PromptDataType
from pyrit.prompt_converter import PromptConverter, ConverterResult

logger = logging.getLogger(__name__)


class AddImageTextConverter(PromptConverter):
"""
Adds a string to an image and wraps the text into multiple lines if necessary.
This class is similar to AddImageTextConverter except
we pass in an image file path as an argument to the constructor as opposed to text.

Args:
img_to_add (str): file path of image to add text to
font_name (str, optional): path of font to use. Must be a TrueType font (.ttf). Defaults to "arial.ttf".
color (tuple, optional): color to print text in, using RGB values. Defaults to (0, 0, 0).
font_size (optional, float): size of font to use. Defaults to 15.
x_pos (int, optional): x coordinate to place text in (0 is left most). Defaults to 10.
y_pos (int, optional): y coordinate to place text in (0 is upper most). Defaults to 10.
output_filename (optional, str): filename to store converted image. If not provided a unique UUID will be used
"""

def __init__(
self,
img_to_add: str,
font_name: Optional[str] = "arial.ttf",
color: Optional[tuple[int, int, int]] = (0, 0, 0),
font_size: Optional[int] = 15,
x_pos: Optional[int] = 10,
y_pos: Optional[int] = 10,
output_filename: Optional[str] = None,
):
if not img_to_add:
raise ValueError("Please provide valid image path")
if not font_name.endswith(".ttf"):
raise ValueError("The specified font must be a TrueType font with a .ttf extension")
self._img_to_add = img_to_add
self._font_name = font_name
self._font_size = font_size
self._font = self._load_font()
self._color = color
self._x_pos = x_pos
self._y_pos = y_pos
self._output_name = output_filename

def _load_font(self):
"""
Load the font for a given font name and font size

Returns:
ImageFont.FreeTypeFont or ImageFont.ImageFont: The loaded font object. If the specified font
cannot be loaded, the default font is returned.

Raises:
OSError: If the font resource cannot be loaded, a warning is logged and the default font is used instead.
"""
# Try to load the specified font
try:
font = ImageFont.truetype(self._font_name, self._font_size)
except OSError:
logger.warning(f"Cannot open font resource: {self._font_name}. Using default font.")
font = ImageFont.load_default()
return font

def _add_text_to_image(self, text: str) -> Image.Image:
"""
Adds wrapped text to the image at self._img_to_add.

Args:
text (str): The text to add to the image.

Returns:
Image.Image: The image with added text.
"""
if not text:
raise ValueError("Please provide valid text value")
# Open the image and create a drawing object
image = Image.open(self._img_to_add)
draw = ImageDraw.Draw(image)

# Calculate the maximum width in pixels with margin into account
margin = 5
max_width_pixels = image.size[0] - margin

# Estimate the maximum chars that can fit on a line
alphabet_letters = string.ascii_letters # This gives 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
bbox = draw.textbbox((0, 0), alphabet_letters, font=self._font)
avg_char_width = (bbox[2] - bbox[0]) / len(alphabet_letters)
max_chars_per_line = int(max_width_pixels // avg_char_width)

# Wrap the text
wrapped_text = textwrap.fill(text, width=max_chars_per_line)

# Add wrapped text to image
y_offset = self._y_pos
for line in wrapped_text.split("\n"):
draw.text((self._x_pos, y_offset), line, font=self._font, fill=self._color)
bbox = draw.textbbox((self._x_pos, y_offset), line, font=self._font)
line_height = bbox[3] - bbox[1]
y_offset += line_height

return image

async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
"""
Converter that overlays input text on the img_to_add.

Args:
prompt (str): The prompt to be added to the image.
input_type (PromptDataType): type of data
Returns:
ConverterResult: The filename of the converted image as a ConverterResult Object
"""
if not self.input_supported(input_type):
raise ValueError("Input type not supported")

img_serializer = data_serializer_factory(value=self._img_to_add, data_type="image_path")

# Add text to the image
updated_img = self._add_text_to_image(text=prompt)

image_bytes = BytesIO()
mime_type = img_serializer.get_mime_type(self._img_to_add)
image_type = mime_type.split("/")[-1]
updated_img.save(image_bytes, format=image_type)
image_str = base64.b64encode(image_bytes.getvalue())
img_serializer.save_b64_image(data=image_str, output_filename=self._output_name)
return ConverterResult(output_text=img_serializer.value, output_type="image_path")

def input_supported(self, input_type: PromptDataType) -> bool:
return input_type == "text"
130 changes: 130 additions & 0 deletions tests/converter/test_add_image_text_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


import pytest
import os

from pyrit.prompt_converter import AddImageTextConverter, AddTextImageConverter

from PIL import Image, ImageFont


@pytest.fixture
def image_text_converter_sample_image():
img = Image.new("RGB", (100, 100), color=(125, 125, 125))
img.save("test.png")
return "test.png"


def test_add_image_text_converter_initialization(image_text_converter_sample_image):
converter = AddImageTextConverter(
img_to_add=image_text_converter_sample_image,
font_name="arial.ttf",
color=(255, 255, 255),
font_size=20,
x_pos=10,
y_pos=10,
output_filename="sample_conv_image.png",
)
assert converter._img_to_add == "test.png"
assert converter._font_name == "arial.ttf"
assert converter._color == (255, 255, 255)
assert converter._font_size == 20
assert converter._x_pos == 10
assert converter._y_pos == 10
assert converter._font is not None
assert type(converter._font) is ImageFont.FreeTypeFont
assert converter._output_name == "sample_conv_image.png"
os.remove("test.png")


def test_add_image_text_converter_invalid_font(image_text_converter_sample_image):
with pytest.raises(ValueError):
AddImageTextConverter(
img_to_add=image_text_converter_sample_image, font_name="arial.otf"
) # Invalid font extension
os.remove("test.png")


def test_add_image_text_converter_null_img_to_add():
with pytest.raises(ValueError):
AddImageTextConverter(img_to_add="", font_name="arial.ttf")


def test_add_image_text_converter_fallback_to_default_font(image_text_converter_sample_image, caplog):
AddImageTextConverter(
img_to_add=image_text_converter_sample_image,
font_name="nonexistent_font.ttf",
color=(255, 255, 255),
font_size=20,
x_pos=10,
y_pos=10,
)
assert any(
record.levelname == "WARNING" and "Cannot open font resource" in record.message for record in caplog.records
)
os.remove("test.png")


def test_image_text_converter_add_text_to_image(image_text_converter_sample_image):
converter = AddImageTextConverter(
img_to_add=image_text_converter_sample_image, font_name="arial.ttf", color=(255, 255, 255)
)
image = Image.open("test.png")
pixels_before = list(image.getdata())
updated_image = converter._add_text_to_image("Sample Text!")
pixels_after = list(updated_image.getdata())
assert updated_image
# Check if at least one pixel changed, indicating that text was added
assert pixels_before != pixels_after
os.remove("test.png")


@pytest.mark.asyncio
async def test_add_image_text_converter_invalid_input_text(image_text_converter_sample_image) -> None:
converter = AddImageTextConverter(img_to_add=image_text_converter_sample_image)
with pytest.raises(ValueError):
assert await converter.convert_async(prompt="", input_type="text") # type: ignore
os.remove("test.png")


@pytest.mark.asyncio
async def test_add_image_text_converter_invalid_file_path():
converter = AddImageTextConverter(img_to_add="nonexistent_image.png", font_name="arial.ttf")
with pytest.raises(FileNotFoundError):
assert await converter.convert_async(prompt="Sample Text!", input_type="text") # type: ignore


@pytest.mark.asyncio
async def test_add_image_text_converter_convert_async(image_text_converter_sample_image) -> None:
converter = AddImageTextConverter(
img_to_add=image_text_converter_sample_image, output_filename="sample_conv_image.png"
)
converted_image = await converter.convert_async(prompt="Sample Text!", input_type="text")
assert converted_image
assert converted_image.output_text == "sample_conv_image.png"
assert converted_image.output_type == "image_path"
assert os.path.exists(converted_image.output_text)
os.remove(converted_image.output_text)
os.remove("test.png")


def test_text_image_converter_input_supported(image_text_converter_sample_image):
converter = AddImageTextConverter(img_to_add=image_text_converter_sample_image)
assert converter.input_supported("image_path") is False
assert converter.input_supported("text") is True


@pytest.mark.asyncio
async def test_add_image_text_converter_equal_to_add_text_image(image_text_converter_sample_image) -> None:
converter = AddImageTextConverter(img_to_add=image_text_converter_sample_image)
converted_image = await converter.convert_async(prompt="Sample Text!", input_type="text")
text_image_converter = AddTextImageConverter(text_to_add="Sample Text!")
converted_text_image = await text_image_converter.convert_async(prompt="test.png", input_type="image_path")
pixels_image_text = list(Image.open(converted_image.output_text).getdata())
pixels_text_image = list(Image.open(converted_text_image.output_text).getdata())
assert pixels_image_text == pixels_text_image
os.remove(converted_image.output_text)
os.remove("test.png")
os.remove(converted_text_image.output_text)
Loading