Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PixtralProcessor to return outputs for all examples in a batch #34321

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/models/pixtral/test_processor_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,39 @@ def prepare_image_inputs(self, batch_size: Optional[int] = None):
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
return [[super().prepare_image_inputs()]] * batch_size

def test_processor_with_batch_of_images_and_text(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_strings = [
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
"USER: [IMG]\nDescribe the image. ASSISTANT:",
]

# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}

# Test passing in a batch of images and text
inputs = processor(text=prompt_strings, images=[[self.image_0], [self.image_1]], return_tensors="pt")
self.assertIn("input_ids", inputs)
self.assertTrue(len(inputs["input_ids"]) == 2)
self.assertIsInstance(inputs["input_ids"], torch.Tensor)
self.assertIsInstance(inputs["pixel_values"], list)
self.assertTrue(len(inputs["pixel_values"]) == 2)
self.assertIsInstance(inputs["pixel_values"][0], list)
self.assertTrue(len(inputs["pixel_values"][0]) == 1)
self.assertIsInstance(inputs["pixel_values"][0][0], torch.Tensor)
Comment on lines +272 to +279
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit curious why we need all these asserts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey molbap, thanks for the feedback! I added those asserts to make absolutely sure the PixtralProcessor is doing its job correctly. It's like a double-check system. Here's why each one is important:

self.assertIn("input_ids", inputs): This checks if the processor created the "input_ids", which are like a secret code that Pixtral needs to understand the text.
self.assertTrue(len(inputs["input_ids"]) == 2): This makes sure we have the right amount of code, since we're testing with 2 pieces of text.
self.assertIsInstance(inputs["input_ids"], torch.Tensor): This ensures the code is in the right format (a torch.Tensor) that Pixtral can use.

self.assertIsInstance(inputs["pixel_values"], list): This checks that the image information is stored correctly in a list.
self.assertTrue(len(inputs["pixel_values"]) == 2): This confirms we have image information for both images we're using.
self.assertIsInstance(inputs["pixel_values"][0], list) and self.assertTrue(len(inputs["pixel_values"][0]) == 1): These make sure each image's information is organized correctly within the list.
self.assertIsInstance(inputs["pixel_values"][0][0], torch.Tensor): This ensures the actual image data is in the right format (torch.Tensor) for Pixtral.


# fmt: off
input_ids = inputs["input_ids"]
self.assertEqual(
input_ids[0].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
self.assertEqual(
input_ids[1].tolist(),
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nDescribe the image. ASSISTANT:"
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on