Skip to content

Commit

Permalink
Add LlavaImageProcessor (#33191)
Browse files Browse the repository at this point in the history
* First draft

* Add equivalence test

* Update docstrings

* Add tests

* Use numpy

* Fix tests

* Improve variable names

* Improve docstring

* Add link

* Remove script

* Add copied from

* Address comment

* Add note in docs

* Add docstring, data format

* Improve test

* Add test

* update

* Update src/transformers/models/llava/image_processing_llava.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/llava/image_processing_llava.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* loop once only

---------

Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
  • Loading branch information
4 people authored Jan 21, 2025
1 parent 8e4cedd commit 78f5ee0
Show file tree
Hide file tree
Showing 7 changed files with 667 additions and 4 deletions.
15 changes: 15 additions & 0 deletions docs/source/en/model_doc/llava.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ For multiple turns conversation:
"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
```

## Note regarding reproducing original implementation

In order to match the logits of the [original implementation](https://github.com/haotian-liu/LLaVA/tree/main), one needs to additionally specify `do_pad=True` when instantiating `LLavaImageProcessor`:

```python
from transformers import LLavaImageProcessor

image_processor = LLavaImageProcessor.from_pretrained("https://huggingface.co/llava-hf/llava-1.5-7b-hf", do_pad=True)
```

### Using Flash Attention 2

Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one).
Expand All @@ -180,6 +190,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] LlavaConfig

## LlavaImageProcessor

[[autodoc]] LlavaImageProcessor
- preprocess

## LlavaProcessor

[[autodoc]] LlavaProcessor
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,7 @@
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.llava"].append("LlavaImageProcessor")
_import_structure["models.llava_next"].append("LlavaNextImageProcessor")
_import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor")
_import_structure["models.llava_onevision"].extend(
Expand Down Expand Up @@ -6334,6 +6335,7 @@
LayoutLMv3ImageProcessor,
)
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.llava import LlavaImageProcessor
from .models.llava_next import LlavaNextImageProcessor
from .models.llava_next_video import LlavaNextVideoImageProcessor
from .models.llava_onevision import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
("llava", ("CLIPImageProcessor",)),
("llava", ("LlavaImageProcessor",)),
("llava_next", ("LlavaNextImageProcessor",)),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
("llava_onevision", ("LlavaOnevisionImageProcessor",)),
Expand Down
436 changes: 436 additions & 0 deletions src/transformers/models/llava/image_processing_llava.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/transformers/models/llava/processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ class LlavaProcessorKwargs(ProcessingKwargs, total=False):

class LlavaProcessor(ProcessorMixin):
r"""
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
[`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
Args:
image_processor ([`CLIPImageProcessor`], *optional*):
image_processor ([`LlavaImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_vision_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class LlavaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class LlavaNextImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

Expand Down
203 changes: 203 additions & 0 deletions tests/models/llava/test_image_processing_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from typing import Tuple, Union

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
from PIL import Image

from transformers import LlavaImageProcessor


class LlavaImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_pad=True,
do_resize=True,
size=None,
do_center_crop=True,
crop_size=None,
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_pad = do_pad
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb

def prepare_image_processor_dict(self):
return {
"do_pad": self.do_pad,
"do_resize": self.do_resize,
"size": self.size,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}

# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
def expected_output_image_shape(self, images):
return self.num_channels, self.crop_size["height"], self.crop_size["width"]

# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)


@require_torch
@require_vision
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Llava
class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = LlavaImageProcessor if is_vision_available() else None

def setUp(self):
super().setUp()
self.image_processor_tester = LlavaImageProcessingTester(self)

@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

# Ignore copy
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

# Ignore copy
def test_padding(self):
"""
LLaVA needs to pad images to square size before processing as per orig implementation.
Checks that image processor pads images correctly given different background colors.
"""

# taken from original implementation: https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L152
def pad_to_square_original(
image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
) -> Image.Image:
width, height = image.size
if width == height:
return image
elif width > height:
result = Image.new(image.mode, (width, width), background_color)
result.paste(image, (0, (width - height) // 2))
return result
else:
result = Image.new(image.mode, (height, height), background_color)
result.paste(image, ((height - width) // 2, 0))
return result

image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

# test with images in channel-last and channel-first format
for image in image_inputs:
padded_image = image_processor.pad_to_square(image)
padded_image_original = pad_to_square_original(Image.fromarray(image))
padded_image_original = np.array(padded_image_original)

np.testing.assert_allclose(padded_image, padded_image_original)

padded_image = image_processor.pad_to_square(image.transpose(2, 0, 1), input_data_format="channels_first")
padded_image = padded_image.transpose(1, 2, 0)

np.testing.assert_allclose(padded_image, padded_image_original)

# test background color
background_color = (122, 116, 104)
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
padded_image_original = pad_to_square_original(Image.fromarray(image), background_color=background_color)
padded_image_original = np.array(padded_image_original)

np.testing.assert_allclose(padded_image, padded_image_original)

background_color = 122
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
padded_image_original = pad_to_square_original(Image.fromarray(image), background_color=background_color)
padded_image_original = np.array(padded_image_original)

np.testing.assert_allclose(padded_image, padded_image_original)

# background color length should match channel length
with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104))

with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104, 0, 0))

@unittest.skip(reason="LLaVa does not support 4 channel images yet")
# Ignore copy
def test_call_numpy_4_channels(self):
pass

0 comments on commit 78f5ee0

Please sign in to comment.