Blip: get/set input embeddings correctly #34152
Conversation
🤗 thanks
```python
def get_input_embeddings(self):
    return self.text_model.get_input_embeddings()

def set_input_embeddings(self, value):
    self.text_model.set_input_embeddings(value)
```
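For context, a minimal sketch of how these accessors behave after the change (the checkpoint name is illustrative; any BLIP checkpoint would do):

```python
from transformers import BlipModel

model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")

# With the delegation above, this returns the text embedding layer,
# not the vision patch embeddings.
embeddings = model.get_input_embeddings()
print(embeddings.weight.shape)  # (text vocab_size, hidden_size)
```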
I think if there is a `text_config`, we could automatically deduce from the config key which submodule to call (here it would be `text_model`). Thinking about this general API-wise!
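A rough sketch of what that deduction could look like (hypothetical; this is not how `transformers` currently implements it):

```python
def get_input_embeddings(self):
    # Hypothetical: map the composite config key "text_config" to the
    # submodule name "text_model" and delegate to it when both exist.
    if hasattr(self.config, "text_config") and hasattr(self, "text_model"):
        return self.text_model.get_input_embeddings()
    # Otherwise fall back to the default PreTrainedModel behaviour.
    return super().get_input_embeddings()
```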
Hmm, I see that in `PreTrainedModel` we try to get the method from `base_model`, and we could probably fall back to that by setting the `base_model_prefix`. I am not sure yet how the prefix is used when loading the model, so let me quickly check that the state dict is still loaded correctly.
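For reference, the fallback in `PreTrainedModel` is roughly the following (paraphrased from `modeling_utils.py`; check the source for the exact code):

```python
def get_input_embeddings(self):
    # Resolve the submodule named by `base_model_prefix`; if it is a
    # different module, delegate the lookup to it.
    base_model = getattr(self, self.base_model_prefix, self)
    if base_model is not self:
        return base_model.get_input_embeddings()
    raise NotImplementedError
```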
Update: yes, the idea works, and loading happens the same way as without the `base_model_prefix`. But some of the tests will fail because of the composite nature of `BlipConfig` (`test_correct_missing_keys`).
I will note this down and add it to my TODO list, but I believe it would force us to refactor `from_pretrained` to work well with composite models.
Okay.
Also: we might need/want to force `return_dict` to `True`, to avoid all the if/else. It would make it simpler!
🤗
```diff
@@ -1771,11 +1771,12 @@ def forward(
     decoder_attention_mask=decoder_attention_mask,
     output_attentions=output_attentions,
     output_hidden_states=output_hidden_states,
-    return_dict=return_dict,
+    return_dict=True,  # toggle for easier access to loss/logits below
```
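The usual pattern in the library lets the outer `forward` keep honoring the caller's `return_dict` while the inner call is fixed to `True`; a simplified sketch (the signature and fields are illustrative):

```python
def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None):
    # The inner call always uses return_dict=True, so loss/logits are
    # reachable by attribute instead of by tuple index.
    decoder_outputs = self.text_decoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        return_dict=True,
    )
    loss, logits = decoder_outputs.loss, decoder_outputs.logits

    # The caller's preference is still honored at the boundary.
    if not return_dict:
        output = (logits,)
        return ((loss,) + output) if loss is not None else output
    return decoder_outputs
```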
Sorry 😐 I realized this would break TorchScript or FX export compatibility, so maybe `False` by default? (I might be wrong though, but I don't think it's supported.)
Yeah, TorchScript is not supported for BLIP AFAIK, and the tests are disabled for that reason. I guess in that case we don't need it to be `False`.
No, but you could script only the LM model and not the full model, no?
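A sketch of that idea, assuming the model is loaded with `torchscript=True` so submodules return tuples (checkpoint name and input shapes are illustrative):

```python
import torch
from transformers import BlipForConditionalGeneration

model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torchscript=True
)

# Trace only the language-model half rather than the full composite model.
input_ids = torch.randint(0, model.config.text_config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
traced_decoder = torch.jit.trace(model.text_decoder, (input_ids, attention_mask))
```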
I added TorchScript tests and they are currently passing. An FX test cannot be added because the model architecture is not in the supported list.
I don't think we should default to `False`, as that would add more complexity than before, when we passed the actual `return_dict`: we'd have to manually wrap the tuple outputs into the correct `ModelOutput` class whenever `return_dict=True`. If you think we should still not set `True` by default, let's go back to the very first solution I proposed.
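To make the extra complexity concrete: if the inner call defaulted to `return_dict=False`, the wrapper would have to rebuild the typed output by hand, along these lines (sketch; the tuple order assumes `labels` were passed, so the loss comes first):

```python
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None):
    # The inner module returns a plain tuple when return_dict=False...
    decoder_outputs = self.text_decoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        return_dict=False,
    )
    loss, logits = decoder_outputs[0], decoder_outputs[1]

    # ...so every field must be re-wrapped manually to return a typed output.
    if return_dict:
        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)
    return decoder_outputs
```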
Okay sounds good!
Let's go!
* set-get embeds
* add tests
* fix tests
* remove
* return dict True
* fix tests
* why did i remove this
* enable torchscript tests
What does this PR do?
Fixes #34109 and adds a `get_input_embeddings` method to the retrieval model. Also fixes the same methods in the BLIP model, where we should be working with text embeddings: returning vision embeddings means the vocab size cannot be resized. Added tests, as those were all skipped and thus we never knew there was an issue.