
Blip: get/set input embeddings correctly #34152

Merged · 13 commits · Nov 1, 2024
31 changes: 24 additions & 7 deletions src/transformers/models/blip/modeling_blip.py
@@ -795,6 +795,12 @@ def __init__(self, config: BlipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.text_model.get_input_embeddings()

def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)
Comment on lines +798 to +802
Collaborator

I think if there is a text_config, we could automatically deduce which submodule to call from the key, which here would be text_model (thinking about the general API!)

Member Author

Hmm, I see that in PreTrainedModel we try to get the method from the base_model, and we could probably fall back to that by setting the base_model_prefix.

I am not yet sure how the prefix is used when loading the model, so let me quickly check that the state dict is still loaded correctly.
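
For reference, a simplified sketch of how that fallback in PreTrainedModel resolves the accessors through base_model_prefix (not the exact implementation, just the idea):

import torch.nn as nn

class PreTrainedModelSketch(nn.Module):
    # Subclasses set base_model_prefix (e.g. "text_model"); the base class
    # then delegates get/set_input_embeddings to that submodule if present,
    # otherwise the subclass has to override these methods itself.
    base_model_prefix = ""

    def get_input_embeddings(self) -> nn.Module:
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError

    def set_input_embeddings(self, value: nn.Module):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError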

Member Author

Update: yes, the idea works, and loading happens the same way as without the base_model_prefix. But some of the tests will fail because of the composite nature of BlipConfig (test_correct_missing_keys).

I will take note of this and add it to my TODO list, but I believe it would force us to refactor from_pretrained to work well with composite models.

Collaborator

Okay
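
As a quick usage sketch of what these accessors enable (the checkpoint name is only an example), resizing the token embeddings now resolves through the text model instead of the vision patch embedding:

from transformers import BlipForConditionalGeneration

# Example checkpoint; any BLIP captioning checkpoint would do.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# get_input_embeddings() now returns the text decoder's word embeddings,
# so resize_token_embeddings can grow them when new tokens are added.
old_size = model.get_input_embeddings().num_embeddings
model.resize_token_embeddings(old_size + 8)
assert model.get_input_embeddings().num_embeddings == old_size + 8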


@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
@@ -1053,8 +1059,11 @@ def __init__(self, config: BlipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def get_input_embeddings(self):
return self.text_decoder.get_input_embeddings()

def set_input_embeddings(self, value):
self.text_decoder.set_input_embeddings(value)

@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
@@ -1117,7 +1126,8 @@ def forward(
)

if not return_dict:
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
outputs = (outputs[0], outputs[1]) if labels is not None else (outputs[0],)
outputs += (image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)

return BlipForConditionalGenerationModelOutput(
@@ -1232,8 +1242,12 @@ def __init__(self, config: BlipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def set_input_embeddings(self, value):
self.text_encoder.set_input_embeddings(value)

def get_input_embeddings(self):
# This will return shared embeddings if they are shared else specific to encoder.
return self.text_encoder.get_input_embeddings()

@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
@@ -1474,8 +1488,11 @@ def __init__(self, config: BlipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def get_input_embeddings(self):
return self.text_encoder.get_input_embeddings()

def set_input_embeddings(self, value):
self.text_encoder.set_input_embeddings(value)

@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
6 changes: 6 additions & 0 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -817,6 +817,12 @@ def __init__(self, config):
self.cls = BlipTextOnlyMLMHead(config)
self.label_smoothing = config.label_smoothing

def get_input_embeddings(self):
return self.bert.get_input_embeddings()

def set_input_embeddings(self, new_embeddings):
self.bert.set_input_embeddings(new_embeddings)

def get_output_embeddings(self):
return self.cls.predictions.decoder

28 changes: 24 additions & 4 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1774,8 +1774,12 @@ def forward(
return_dict=return_dict,
labels=labels,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if labels is not None:
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
else:
loss = None
logits = outputs.logits if return_dict else outputs[0]

if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
@@ -1813,6 +1817,12 @@ def __init__(self, config: Blip2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

@add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
def forward(
@@ -2243,8 +2253,12 @@ def forward(
return_dict=return_dict,
labels=labels,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if labels is not None:
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
else:
loss = None
logits = outputs.logits if return_dict else outputs[0]

if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
@@ -2396,6 +2410,12 @@ def __init__(self, config: Blip2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

@add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
def forward(
12 changes: 5 additions & 7 deletions tests/models/blip/test_modeling_blip.py
@@ -443,7 +443,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False

def setUp(self):
@@ -737,7 +737,6 @@ def prepare_config_and_inputs_for_common(self):
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
@@ -786,10 +785,10 @@ def prepare_config_and_inputs_for_common(self):
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"decoder_input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"labels": input_ids,
}
return config, inputs_dict

@@ -801,7 +800,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

@@ -810,7 +809,6 @@ def setUp(self):

def _prepare_inputs_for_vqa(self):
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["labels"] = inputs_dict["input_ids"]
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict.pop("return_loss")
return inputs_dict
@@ -881,7 +879,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

@@ -1109,7 +1107,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

12 changes: 5 additions & 7 deletions tests/models/blip_2/test_modeling_blip_2.py
@@ -442,7 +442,6 @@ def prepare_config_and_inputs_for_common(self):
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
}
return config, inputs_dict

@@ -692,7 +691,6 @@ def prepare_config_and_inputs_for_common(self):
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"labels": labels,
}
return config, inputs_dict

@@ -712,7 +710,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

@@ -818,7 +816,7 @@ def test_get_text_features(self):
def test_get_image_features(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]

for key in keys_to_pop:
inputs_dict.pop(key)
@@ -838,7 +836,7 @@ def test_get_qformer_features(self):
def test_get_qformer_features(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]

for key in keys_to_pop:
inputs_dict.pop(key)
@@ -948,7 +946,7 @@ class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_head_masking = False

test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

@@ -1272,7 +1270,7 @@ class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

2 changes: 1 addition & 1 deletion tests/models/instructblip/test_modeling_instructblip.py
@@ -457,7 +457,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

@@ -478,7 +478,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False

2 changes: 2 additions & 0 deletions tests/test_modeling_common.py
@@ -1814,6 +1814,7 @@ def test_resize_tokens_embeddings(self):
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict.pop("labels", None)

for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
@@ -1991,6 +1992,7 @@ def test_resize_embeddings_untied(self):

original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False
inputs_dict.pop("labels", None)

# if model cannot untied embeddings -> leave test
if original_config.tie_word_embeddings:
Expand Down