Update `past_key_values` in GPT-2 #9596
Conversation
The CircleCI error messages say as below.

Is there a difference between the two? I first thought it might be a difference between the causal language model and the Seq2Seq language model, but it seems that both are handled similarly. As for the contents of `transformers/src/transformers/models/bart/modeling_bart.py`, lines 1236 to 1244 in 236cc36:

I've updated it. See `transformers/src/transformers/models/xlnet/modeling_xlnet.py`, lines 581 to 607 in 236cc36 — it seems similar.
Hey @forest1988, your PR looks very nice! Yes, it is expected that `_reorder_cache` defaults to:

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(...)
```
I've just updated it.
This way it's much cleaner and correct :-) The reason I'm proposing this change is that `_reorder_cache` strongly differs from model to model, so the default in `generation_utils.py` should be:

```python
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(
        f"Make sure that a `_reorder_cache` function is correctly implemented in "
        f"{self.__class__.__module__} to enable beam search for {self.__class__}"
    )
```

I think this should solve the problems, let me know if you need more help :-)
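For context, a minimal sketch of what a model-specific `_reorder_cache` could then look like for GPT-2 once the cache is a tuple of `(key, value)` tuples per layer; the merged implementation may differ slightly:

```python
def _reorder_cache(self, past, beam_idx):
    # For every layer, select the cached key/value states that belong to the
    # beams that were kept at this generation step.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past
    )
```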
Thank you for your advice! I'll update it accordingly.
Thanks to your kind advice, I could solve the problem. The last remaining bug is:

I think I should modify it.
All checks have passed! However, in the documentation of `past_key_values` in GPT-2:
> called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
> generation step.
>
> For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
remove those lines and `past_key_values` above
I cleaned it as well.
The PR looks very nice - thanks so much for taking the time to tackle this @forest1988. Let's wait a bit to see how to proceed with `gradient_checkpointing` in GPT2, as this question will come up more often. IMO, `use_cache` should always be `False` for training, so either we update all `use_cache` in the models with a `use_cache = not self.is_training and (use_cache if use_cache is not None else self.config.use_cache)`, or we force it somehow in the Trainer. Similarly, `gradient_checkpointing` should never be set to `True` when the model is not training IMO (we could also automatically disable this using `self.training`). Let's see what @LysandreJik and @sgugger think.
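As a rough sketch of the first option above (hedged: `torch.nn.Module` exposes the training flag as `self.training`, and the exact condition adopted in the library may differ):

```python
# Inside a model's forward(): resolve use_cache, but never cache during training.
use_cache = use_cache if use_cache is not None else self.config.use_cache
if self.training:
    use_cache = False  # the key/value cache is only useful at inference time
```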
This is not a part of the library I'm very familiar with, so the changes look okay on my side, but I'm no expert.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
These changes look good to me! Thanks for taking care of it @forest1988.
Great work @forest1988,
I hope it's fine for you that I went into the PR to do some final fixes. Thanks a lot for cleaning this up :-)
Of course! Thank you for adding fixes to make this PR more valuable!
Your commit looks good to me @patrickvonplaten! Thanks.
The new changes look good to me, thanks!
Awesome, merging - great job @forest1988!

Thank you for your advice and encouraging comments!
```diff
@@ -232,7 +232,7 @@ def forward(
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
```
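To illustrate the change, a standalone sketch (not the actual model code; shapes are assumed for the example):

```python
import torch

batch, heads, seq, head_dim = 2, 12, 5, 64
key = torch.randn(batch, heads, head_dim, seq)    # GPT-2 keeps the key transposed
value = torch.randn(batch, heads, seq, head_dim)

# Old format: a single stacked tensor of shape (2, batch, heads, seq, head_dim)
present_old = torch.stack((key.transpose(-2, -1), value))
past_key, past_value = present_old[0], present_old[1]

# New format: a plain tuple of two tensors, matching the Bart-style cache
present_new = (key.transpose(-2, -1), value)
past_key, past_value = present_new
```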
This is the reason for the recent failure of the slow test:

```
RUN_SLOW=1 pytest tests/test_onnx.py::OnnxExportTestCase::test_export_pytorch
```

Can you fix the onnx part easily? @mfuntowicz @Narsil
What does this PR do?

It seems GPT-2 and BartDecoder have different styles of `past_key_values`. Advised by @patrickvonplaten, I opened this PR to change GPT-2's cache format from a single tensor to a tuple of 2 tensors. Once this problem is solved, it is expected that `past_key_values` in GPT-2 will be handled in the same way as in Bart. Sorry, some errors remain; this PR is [WIP].

I would appreciate your advice on how to update `generation_utils.py`. Can I modify `_reorder_cache` so that `past` is changed from `Tuple[torch.Tensor]` to `Tuple[Tuple[torch.Tensor]]`, or should I consider other output variations, `outputs.mems` and `outputs.past_buckets_states`?
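In type terms, the question is roughly the following (a sketch with assumed shapes, one cache entry per layer):

```python
from typing import Tuple
import torch

# Old GPT-2 cache: one stacked tensor per layer,
# each of shape (2, batch, num_heads, seq_len, head_dim)
OldCache = Tuple[torch.Tensor, ...]

# Proposed Bart-like cache: one (key, value) pair per layer,
# each tensor of shape (batch, num_heads, seq_len, head_dim)
NewCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
```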
Fixes #9391
From patrickvonplaten:

This PR cleans up the `_reorder_cache` logic. Now `_reorder_cache` defaults to raising a `NotImplementedError` in `generation_utils.py`, forcing each model to implement its corresponding `_reorder_cache` in the `modeling_...py` file itself. This is cleaner, as `_reorder_cache` strongly differs from model to model. In addition, this PR makes sure that `gradient_checkpointing` can only be used if the model is in training mode, and makes sure that `use_cache` is disabled when training and `gradient_checkpointing` is enabled, to prevent errors.
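A minimal sketch of that last safeguard (the attribute name and warning text in the merged code may differ):

```python
# Inside forward(): gradient checkpointing re-runs segments during the backward
# pass, which clashes with caching, so drop the cache while training with it.
if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache:
    use_cache = False  # a logger.warning(...) could tell the user why caching was disabled
```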
Who can review?
GPT2: @LysandreJik, @patrickvonplaten