Expand inputs in processors for VLMs (#30962)
* let it be

* draft

* should not have changed

* add warnings

* fix & add tests

* fix tests

* inputs embeds cannot be passed with pixels

* more updates

* paligemma ready!

* minor typos

* update blip-2

* fix tests & raise error

* docstring

* add blip2 test

* tmp

* add image seq length to config

* update docstring

* delete

* fix tests

* fix blip

* fix paligemma

* out-of-place scatter

* add llava-next-video

* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* remove tmp

* codestyle

* nits

* more nits

* remove overriding in tests

* comprehension when merging video

* fix-copies

* revert changes for embeds test

* fix tests after making comprehension

* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* Update src/transformers/models/blip_2/processing_blip_2.py

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* more updates

* fix tests

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
zucchini-nlp and molbap authored Aug 13, 2024
1 parent 2a5a6ad commit a29eabd
Showing 37 changed files with 1,945 additions and 796 deletions.
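
At a high level, this PR moves image-token expansion out of the VLM forward/generate methods and into the processors: each processor now prefixes the prompt with one `<image>` placeholder per query token, and the model config gains an `image_token_index`, so the language model can drop the vision embeddings directly into those positions. A rough, hypothetical sketch of the processor-side effect (toy token ids, not real tokenizer output):

num_query_tokens = 4
image_token_id = 99            # hypothetical id of the added "<image>" token
prompt_ids = [1, 17, 23, 2]    # hypothetical tokenized prompt

# One placeholder per query token is prepended, so the sequence length already
# accounts for the image embeddings that will be merged in later.
input_ids = [image_token_id] * num_query_tokens + prompt_ids
attention_mask = [1] * len(input_ids)
print(input_ids)               # [99, 99, 99, 99, 1, 17, 23, 2]
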
13 changes: 12 additions & 1 deletion src/transformers/models/blip_2/configuration_blip_2.py
@@ -264,6 +264,8 @@ class Blip2Config(PretrainedConfig):
num_query_tokens (`int`, *optional*, defaults to 32):
The number of query tokens passed through the Transformer.
image_token_index (`int`, *optional*):
Token index of special image token.
kwargs (*optional*):
Dictionary of keyword arguments.
@@ -299,7 +301,15 @@ class Blip2Config(PretrainedConfig):

model_type = "blip-2"

def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
def __init__(
self,
vision_config=None,
qformer_config=None,
text_config=None,
num_query_tokens=32,
image_token_index=None,
**kwargs,
):
super().__init__(**kwargs)

if vision_config is None:
@@ -323,6 +333,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_token_index = image_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
self.initializer_factor = 1.0
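
For illustration, a minimal usage sketch of the new config field, assuming a transformers release that includes this commit (the token id below is hypothetical; in practice it should be the tokenizer's id for `<image>`):

from transformers import Blip2Config

# Hypothetical id; normally obtained via tokenizer.convert_tokens_to_ids("<image>").
config = Blip2Config(num_query_tokens=32, image_token_index=50265)
print(config.image_token_index)  # 50265
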
55 changes: 41 additions & 14 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1767,12 +1767,25 @@ def forward(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
expected_device = language_model_attention_mask.device
attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
@@ -1876,20 +1889,34 @@ def generate(
.repeat(batch_size, 1)
.to(image_embeds.device)
)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the prepended BOS after `generate`.
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the prepended BOS after `generate`.
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
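
The merging logic above keeps a backward-compatible branch: when the config carries `image_token_index`, the placeholder positions are overwritten with the Q-Former outputs (an out-of-place scatter in `forward`); otherwise the old behaviour of concatenating the image embeddings in front, and growing the attention mask to match, is preserved. A self-contained sketch of the two branches with toy tensors (shapes and ids are made up):

import torch

hidden_size, num_query_tokens, image_token_id = 8, 3, 99
input_ids = torch.tensor([[99, 99, 99, 1, 5, 6, 2]])                   # placeholders already expanded
inputs_embeds = torch.nn.Embedding(100, hidden_size)(input_ids)        # (1, 7, 8)
language_model_inputs = torch.randn(1, num_query_tokens, hidden_size)  # stand-in for Q-Former output
attention_mask = torch.ones_like(input_ids)

use_new_path = True  # i.e. config.image_token_index is not None
if use_new_path:
    # New path: scatter image embeddings into the placeholder slots; length is unchanged.
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(mask, language_model_inputs)
else:
    # Legacy path: prepend the image embeddings and extend the attention mask to match.
    inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
    attention_mask = torch.cat(
        [torch.ones(1, num_query_tokens, dtype=attention_mask.dtype), attention_mask], dim=1
    )
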
59 changes: 50 additions & 9 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -20,8 +20,18 @@

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType, logging


logger = logging.get_logger(__name__)


class Blip2Processor(ProcessorMixin):
@@ -36,20 +46,24 @@ class Blip2Processor(ProcessorMixin):
An instance of [`BlipImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`):
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
num_query_tokens (`int`, *optional*):
Number of tokens used by the Qformer as queries; should be the same as in the model's config.
"""

attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
valid_kwargs = ["num_query_tokens"]
image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer"

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
def __init__(self, image_processor, tokenizer, **kwargs):
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
tokenizer.return_token_type_ids = False
self.current_processor = image_processor
self.image_token = AddedToken("<image>", normalized=False, special=True)
tokenizer.add_tokens([self.image_token], special_tokens=True)
self.num_query_tokens = num_query_tokens

super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

# Copied from transformers.models.blip.processing_blip.BlipProcessor.__call__
def __call__(
self,
images: ImageInput = None,
@@ -106,7 +120,13 @@ def __call__(
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)

if text is not None:
text_encoding = self.tokenizer(
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

text_encoding = {}
_text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
@@ -121,9 +141,30 @@ def __call__(
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
return_tensors=None, # hardcode "None" here for prepending image tokens
**kwargs,
)

# if we know how many query tokens there are, expand the text inside the processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning, even before the BOS token
if self.num_query_tokens is not None:
image_tokens = self.image_token.content * self.num_query_tokens
image_token_encoding = self.tokenizer([image_tokens], add_special_tokens=False, return_tensors=None)
for k in _text_encoding:
text_encoding[k] = [
img_encoding + txt_encoding
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
]
else:
text_encoding = _text_encoding
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)

# cast to desired return tensors type
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
else:
text_encoding = None

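
Note that the processor tokenizes with `return_tensors=None`, prepends the tokenized block of image placeholders to every per-key list, and only afterwards wraps the result in a `BatchEncoding` with the requested tensor type. A simplified sketch of that per-key merge (hypothetical ids, only two keys shown):

image_token_encoding = {"input_ids": [[99, 99, 99]], "attention_mask": [[1, 1, 1]]}
_text_encoding = {"input_ids": [[1, 17, 23, 2]], "attention_mask": [[1, 1, 1, 1]]}

text_encoding = {
    k: [img + txt for img, txt in zip(image_token_encoding[k], _text_encoding[k])]
    for k in _text_encoding
}
print(text_encoding["input_ids"])   # [[99, 99, 99, 1, 17, 23, 2]]
# ...followed by BatchEncoding(text_encoding, tensor_type=return_tensors) in the real code.
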
src/transformers/models/instructblip/configuration_instructblip.py
@@ -269,6 +269,8 @@ class InstructBlipConfig(PretrainedConfig):
num_query_tokens (`int`, *optional*, defaults to 32):
The number of query tokens passed through the Transformer.
image_token_index (`int`, *optional*):
Token index of special image token.
kwargs (*optional*):
Dictionary of keyword arguments.
@@ -304,7 +306,15 @@ class InstructBlipConfig(PretrainedConfig):

model_type = "instructblip"

def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
def __init__(
self,
vision_config=None,
qformer_config=None,
text_config=None,
num_query_tokens=32,
image_token_index=None,
**kwargs,
):
super().__init__(**kwargs)

if vision_config is None:
@@ -328,6 +338,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu
self.is_encoder_decoder = self.text_config.is_encoder_decoder

self.num_query_tokens = num_query_tokens
self.image_token_index = image_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
self.initializer_factor = 1.0
51 changes: 39 additions & 12 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1453,12 +1453,24 @@ def forward(
)

inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
@@ -1580,17 +1592,32 @@ def generate(
)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

# concatenate query embeddings with prompt embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the prepended BOS after `generate`.
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the prepended BOS after `generate`.
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
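
For the legacy concatenation path in `generate`, the image embeddings occupy sequence positions that would otherwise be counted against the text budget, so `max_length` and `min_length` are shifted by the number of image embeddings (minus one for the BOS that `generate` prepends). A small sketch of that bookkeeping with hypothetical numbers:

generate_kwargs = {}
num_image_embeds = 32  # language_model_inputs.shape[1], hypothetical value

generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + num_image_embeds - 1  # 51
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + num_image_embeds       # 32
print(generate_kwargs)  # {'max_length': 51, 'min_length': 32}
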