Blip: get/set input embeddings correctly #34152

Merged
merged 13 commits on Nov 1, 2024
31 changes: 24 additions & 7 deletions src/transformers/models/blip/modeling_blip.py
@@ -795,6 +795,12 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_model.set_input_embeddings(value)
Comment on lines +798 to +802

Collaborator:

I think if there is a text_config, we could automatically deduce which submodule to call from its key, which here would be text_model (thinking about the general API!).

Member Author:

Hmm, I see that in PreTrainedModel we try to get the method from the base_model, and we can probably fall back to that by setting the base_model_prefix.

I am not yet sure how the prefix is used when loading the model, so let me quickly check that the state dict is still loaded correctly.

Member Author:

Update: yes, the idea works and loading happens the same way as without the base_model_prefix. But some of the tests will fail because of the composite nature of BlipConfig (test_correct_missing_keys).

I will take note of this and add it to my TODO list, but I believe it would force us to refactor from_pretrained to work well with composite models.

Collaborator:

Okay
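For reference, the fallback discussed in this thread works roughly as follows: PreTrainedModel's default accessors look up the submodule named by base_model_prefix and delegate to it. A simplified sketch of that mechanism (not the exact transformers implementation; the class name is a stand-in):

```python
import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    """Simplified sketch of the base_model_prefix fallback discussed above."""

    base_model_prefix = ""

    def get_input_embeddings(self) -> nn.Module:
        # Delegate to the submodule named by `base_model_prefix`, if one exists.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError

    def set_input_embeddings(self, value: nn.Module):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError
```

BLIP keeps the explicit overrides instead because its models are composite (vision plus text), so setting a single base_model_prefix would also change which checkpoint keys from_pretrained expects, which is what breaks test_correct_missing_keys as noted above.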


     @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
@@ -1053,8 +1059,11 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.text_decoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_decoder.set_input_embeddings(value)
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
@@ -1117,7 +1126,8 @@ def forward(
         )
 
         if not return_dict:
-            outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
+            outputs = (outputs[0], outputs[1]) if labels is not None else (outputs[0],)
+            outputs += (image_embeds, vision_outputs[0]) + vision_outputs[2:]
             return tuple(output for output in outputs if output is not None)
 
         return BlipForConditionalGenerationModelOutput(
@@ -1232,8 +1242,12 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def set_input_embeddings(self, value):
+        self.text_encoder.set_input_embeddings(value)
+
+    def get_input_embeddings(self):
+        # This will return the shared embeddings if they are shared, otherwise the encoder-specific ones.
+        return self.text_encoder.get_input_embeddings()
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
@@ -1474,8 +1488,11 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.text_encoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_encoder.set_input_embeddings(value)
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
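Taken together, the modeling_blip.py changes above point the embedding accessors at the text token embeddings instead of the vision patch embedding, so standard token-embedding workflows behave as expected. A rough usage sketch (the checkpoint and the added token are only illustrative, not part of this PR):

```python
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# With this PR, get_input_embeddings() returns the text token embedding matrix
# (an nn.Embedding) rather than the vision model's patch embedding.
print(type(model.get_input_embeddings()))

# Resizing token embeddings therefore operates on the right module.
processor.tokenizer.add_tokens(["<new_token>"])
model.resize_token_embeddings(len(processor.tokenizer))
```

This is also what the test_resize_embeddings flags flipped to True in the test changes below exercise.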
6 changes: 6 additions & 0 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -817,6 +817,12 @@ def __init__(self, config):
         self.cls = BlipTextOnlyMLMHead(config)
         self.label_smoothing = config.label_smoothing
 
+    def get_input_embeddings(self):
+        return self.bert.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        self.bert.set_input_embeddings(new_embeddings)
+
     def get_output_embeddings(self):
         return self.cls.predictions.decoder

26 changes: 20 additions & 6 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1768,11 +1768,12 @@ def forward(
                 decoder_attention_mask=decoder_attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                return_dict=True,  # toggle for easier access to loss/logits below
Collaborator:

Sorry 😐 I realized this would break torch.script or FX export compatibility, so maybe keep it False by default? (I might be wrong, but I don't think it's supported.)

Member Author:

Yeah, torchscript is not supported for BLIP as far as I know, and the tests are disabled for that reason. I guess in that case we don't need it to be False.

Collaborator:

No, but you could script only the LM model and not the full model, no?

Member Author:

I added torchscript tests and they are currently passing. An FX test cannot be added because the model architecture is not in the supported list.

I don't think we should make it False by default, as that would add more complexity than before, when we passed the actual return_dict: we'd have to manually wrap the tuple outputs into the correct ModelOutput class when return_dict is requested. If you still think we should not set it to True by default, let's go back to the very first solution I proposed.

Collaborator:

Okay, sounds good!

                 labels=labels,
             )
-            loss = outputs.loss if return_dict else outputs[0]
-            logits = outputs.logits if return_dict else outputs[1]
+            loss = outputs.loss
+            logits = outputs.logits
+            outputs = outputs.to_tuple() if not return_dict else outputs
 
         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
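As discussed in the thread above, the pattern here is to always call the inner language model with return_dict=True, read loss/logits by attribute, and convert back to a tuple only if the caller asked for one. A minimal self-contained sketch of that pattern (LMOutput and run_lm are stand-ins, not the actual BLIP-2 modules):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class LMOutput:
    # Stand-in for a transformers ModelOutput with loss/logits fields.
    loss: Optional[torch.Tensor]
    logits: torch.Tensor

    def to_tuple(self) -> Tuple[torch.Tensor, ...]:
        # Drop None entries, mirroring ModelOutput.to_tuple().
        return tuple(v for v in (self.loss, self.logits) if v is not None)


def run_lm(labels: Optional[torch.Tensor], return_dict: bool):
    # Always request a structured output internally ...
    outputs = LMOutput(
        loss=torch.tensor(0.5) if labels is not None else None,
        logits=torch.randn(1, 4, 10),
    )
    loss = outputs.loss      # attribute access works regardless of `return_dict`
    logits = outputs.logits
    # ... and only fall back to a tuple at the boundary, if the caller wants one.
    return outputs.to_tuple() if not return_dict else outputs


print(type(run_lm(labels=torch.ones(1, 4, dtype=torch.long), return_dict=False)))  # tuple
print(type(run_lm(labels=None, return_dict=True)))  # LMOutput
```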
@@ -1810,6 +1811,12 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
 
     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
     def forward(
@@ -2233,11 +2240,12 @@ def forward(
                 decoder_attention_mask=decoder_attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                return_dict=True,  # toggle for easier access to loss/logits below
                 labels=labels,
             )
-            loss = outputs.loss if return_dict else outputs[0]
-            logits = outputs.logits if return_dict else outputs[1]
+            loss = outputs.loss
+            logits = outputs.logits
+            outputs = outputs.to_tuple() if not return_dict else outputs

         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
@@ -2389,6 +2397,12 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
 
     @add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
     def forward(
12 changes: 5 additions & 7 deletions tests/models/blip/test_modeling_blip.py
@@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
 
     def setUp(self):
@@ -738,7 +738,6 @@ def prepare_config_and_inputs_for_common(self):
         config, input_ids, attention_mask, pixel_values = config_and_inputs
         inputs_dict = {
             "input_ids": input_ids,
-            "labels": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
         }
@@ -787,10 +786,10 @@ def prepare_config_and_inputs_for_common(self):
         config, input_ids, attention_mask, pixel_values = config_and_inputs
         inputs_dict = {
             "input_ids": input_ids,
-            "labels": input_ids,
             "decoder_input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
+            "labels": input_ids,
         }
         return config, inputs_dict

@@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False

@@ -811,7 +810,6 @@ def setUp(self):
 
     def _prepare_inputs_for_vqa(self):
         _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        inputs_dict["labels"] = inputs_dict["input_ids"]
         inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
         inputs_dict.pop("return_loss")
         return inputs_dict
@@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
    test_attention_outputs = False
     test_torchscript = False

@@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False

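Flipping test_resize_embeddings to True enables the common embedding-resize tests for these model classes. A simplified sketch of what such a check exercises (not the actual ModelTesterMixin code):

```python
def check_resize_roundtrip(model):
    # Simplified sketch of the common resize check: grow the vocabulary,
    # confirm the token embedding matrix follows, then shrink it back.
    old_vocab_size = model.get_input_embeddings().weight.shape[0]

    model.resize_token_embeddings(old_vocab_size + 10)
    assert model.get_input_embeddings().weight.shape[0] == old_vocab_size + 10

    model.resize_token_embeddings(old_vocab_size)
    assert model.get_input_embeddings().weight.shape[0] == old_vocab_size
```

Before this PR such a check could not run for the BLIP classes above, because get_input_embeddings() returned the vision patch embedding, which has no token dimension to resize.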