pytorch-transformers returns output of 13 layers? #1332
Comments
I am looking at this too, and I believe (I might be wrong) that the embedding layer sits in the last position. So I guess you should use [-2:-5].
Hm, I don't think so. The embedding state is passed to the forward function, and that state is used to initialize the
Hi Bram,

Please read the details of

The first element of the output tuple of BERT is always the last hidden state, and the full list of hidden states is the last element of the output tuple in your case. These lines:
should be changed to:
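A minimal sketch of the output-tuple layout described here, assuming the `pytorch_transformers` 1.x convention of `(sequence_output, pooled_output)` followed by the optional hidden states and then the optional attentions. Strings stand in for tensors so it runs without PyTorch, and `bert_model_outputs` is a hypothetical helper, not part of the library.

```python
# Hypothetical stand-in for how the model's output tuple is assembled;
# strings replace tensors so the sketch runs without PyTorch.
def bert_model_outputs(num_layers=12, output_hidden_states=False, output_attentions=False):
    """Return (sequence_output, pooled_output, [hidden_states], [attentions])."""
    # Embedding output first, then one entry per encoder layer.
    all_hidden_states = tuple(["embeddings"] + [f"layer_{i}" for i in range(1, num_layers + 1)])
    all_attentions = tuple(f"attn_{i}" for i in range(1, num_layers + 1))

    sequence_output = all_hidden_states[-1]  # output of the last encoder layer
    pooled_output = "pooled"
    outputs = (sequence_output, pooled_output)
    if output_hidden_states:
        outputs = outputs + (all_hidden_states,)
    if output_attentions:
        outputs = outputs + (all_attentions,)
    return outputs


out = bert_model_outputs(output_hidden_states=True)
print(out[0])            # layer_12
print(len(out[2]))       # 13
print(out[2] is out[-1]) # True
```

With only `output_hidden_states=True`, `out[2]` and `out[-1]` are the same tuple; once attentions are also requested, `out[-1]` becomes the attentions, so `out[2]` is the more robust index for the hidden states.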
Hi Thomas, thank you for your time.

Apparently a mistake crept into my comment on GitHub. In my code, I do have the correct version, i.e.

```python
out = self.bert_model(input_ids=bert_ids, attention_mask=bert_mask)
hidden_states = out[2]
```

The question I have is this: when you then print the length of those hidden states, you get different numbers.

```python
print(len(hidden_states))
# 13 for pytorch_transformers, 12 for pytorch_pretrained_bert
```

Going through the source code, it seems that the input hidden state (final hidden state of the embeddings) is included. I couldn't find this documented anywhere, but I am curious about the reasoning behind it: since the embedding state is not an encoder state, it might not be what one expects to get back from the model. On the other hand, it does make it easy for users to get the embeddings.
Hi Bram,

There are a few reasons we did that. One is this great paper by Tenney et al. (http://arxiv.org/abs/1905.05950), which uses the output of the embeddings as well as the hidden states to study BERT's performance. Another is to have easy access to the embeddings, as you mention.
transformers/pytorch_transformers/modeling_bert.py Lines 350 to 352 in 7c0f2d0
But on lines 350-352, it adds the "hidden states" (last layer of the embeddings) to "all_hidden_states", so the last item is the embedding output.
No, by that time the initial
Perhaps the not-so-intuitive part is that the
You are right, thanks for the clarification!
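To see why the 13th entry is the final encoder layer rather than the embeddings, here is a simplified, PyTorch-free sketch modeled on the hidden-state loop in `BertEncoder.forward` that this thread links to. `encoder_forward` and `make_layer` are hypothetical names, and each "layer" is a plain function that only labels its output.

```python
# Simplified sketch of the hidden-state collection loop; real layers are nn.Modules.
def encoder_forward(hidden_states, layers, output_hidden_states=True):
    all_hidden_states = ()
    for layer in layers:
        if output_hidden_states:
            # Store the *input* to this layer; on the first pass that is the embedding output.
            all_hidden_states = all_hidden_states + (hidden_states,)
        hidden_states = layer(hidden_states)
    # "Add last layer": hidden_states was reassigned inside the loop,
    # so it now holds the final layer's output, not the embeddings.
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
    return hidden_states, all_hidden_states


def make_layer(i):
    # Hypothetical layer that just records which layer produced the state.
    return lambda _prev: f"layer_{i}"


layers = [make_layer(i) for i in range(1, 13)]
last, all_states = encoder_forward("embeddings", layers)
print(len(all_states))                # 13
print(all_states[0], all_states[-1])  # embeddings layer_12
```

The key detail is the reassignment `hidden_states = layer(hidden_states)` inside the loop: the first appended entry is the embedding output, and the append after the loop adds the last layer's output.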
@thomwolf Thanks for the clarification. It appears I was looking in all the wrong places. In particular, I had expected this in the migration section of the README. If you want, I can make a small documentation pull request for that.

Re-opened. Will close after the doc change, if requested.
Add small note about the output of hidden states (closes #1332)
📚 Migration
Model I am using (Bert, XLNet....): BertModel
Language I am using the model on (English, Chinese....): English
The problem arises when using:
The task I am working on is:
Details of the issue:
I am using pytorch-transformers for the rather unconventional task of regression (one output). In my research I use BERT, and I'm planning to try out the other transformers as well. When I started, I got good results with `pytorch-pretrained-bert`. However, running the same code with `pytorch-transformers` gives me results that are a lot worse.

In the original code, I use the output of the model and concatenate the last four layers, as was proposed in the BERT paper. The architecture that I used looks like this:
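A minimal sketch of the "concatenate the last four layers" step, written for illustration rather than taken from the author's model. Plain nested lists stand in for tensors of shape `[seq_len, hidden_size]`, and `concat_last_four` is a hypothetical helper; with PyTorch tensors the same operation is typically `torch.cat(hidden_states[-4:], dim=-1)`.

```python
def concat_last_four(hidden_states):
    """Concatenate the last four layers' vectors for each token.

    hidden_states: sequence of layers, each a list of per-token feature vectors.
    Returns one vector per token, of length 4 * hidden_size.
    """
    last_four = hidden_states[-4:]
    seq_len = len(last_four[0])
    return [[x for layer in last_four for x in layer[tok]] for tok in range(seq_len)]


# Toy input: 13 "layers" (index 0 = embeddings), 3 tokens, hidden size 2.
# The vector for layer k, token t is [k, t], so the output is easy to read.
hidden = [[[k, t] for t in range(3)] for k in range(13)]
features = concat_last_four(hidden)
print(features[0])       # [9, 0, 10, 0, 11, 0, 12, 0]
print(len(features[0]))  # 8
```

Using a negative slice means the helper selects the same four layers whether or not the embedding output is included at index 0.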
When porting this to `pytorch-transformers`, the main change was that we now get a tuple back from the model and have to explicitly ask for all hidden states. As such, the converted code looks like this:

As I said before, this leads to very different results. Seeding cannot be the issue, since I set all seeds manually in both cases, like this:
I have added the print statements as a sort of debugging, and I quickly found that there is a fundamental difference between the two architectures. The `hidden_states` print statement will yield `12` for `pytorch-pretrained-bert` and `13` for `pytorch-transformers`! I am not sure how that relates, but I would assume that this is a good place to start looking.

I have tried comparing the created models, but in both cases the encoder consists of 12 layers, so I am not sure why `pytorch-transformers` returns 13. What's the extra one?

Going through the source code, it seems that the first hidden state (= last hidden state from the embeddings) is included. Is that true?
https://github.com/huggingface/pytorch-transformers/blob/7c0f2d0a6a8937063bb310fceb56ac57ce53811b/pytorch_transformers/modeling_bert.py#L340-L352
Even so, since the embeddings would be the first item in `all_hidden_states`, the last four layers should still be the same. Therefore, I am not sure why there is such a big difference between the results of the two versions above. If you spot any faults, please advise.
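The point about the last four layers can be checked with a small sketch, using plain lists with made-up labels in place of the real hidden-state tuples for a 12-layer model. Because the extra entry is prepended, negative slices such as `[-4:]` select the same layers in both libraries, while positive indices are shifted by one. This does not by itself explain the gap in results; it only shows that the extra entry cannot change which layers `[-4:]` picks.

```python
# pytorch-pretrained-bert: one entry per encoder layer (12 for a base model).
old_hidden = [f"layer_{i}" for i in range(1, 13)]
# pytorch-transformers: embedding output prepended, giving 13 entries.
new_hidden = ["embeddings"] + old_hidden

print(len(old_hidden), len(new_hidden))  # 12 13

# Negative slicing counts from the end, so the last four layers match.
print(old_hidden[-4:] == new_hidden[-4:])  # True
print(new_hidden[-4:])  # ['layer_9', 'layer_10', 'layer_11', 'layer_12']

# Positive indices shift by one: old_hidden[i] corresponds to new_hidden[i + 1].
print(all(old_hidden[i] == new_hidden[i + 1] for i in range(len(old_hidden))))  # True
```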
Environment
Checklist