
fix warning trigger for embed_positions when loading xglm #25798

Merged · 2 commits into main · Aug 29, 2023

Conversation

@MattYoon (Contributor)

What does this PR do?

Fixes #25797

The warning no longer triggers when loading XGLM from the Hub. I've followed @younesbelkada's suggestion of making the problematic module non-persistent.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('facebook/xglm-564M')
# no warning
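
For context, here is a minimal sketch of the idea behind the fix (class and attribute names are illustrative, not the exact XGLM diff): registering the precomputed table as a non-persistent buffer keeps it out of the state dict, so from_pretrained no longer reports it as missing.

import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):  # illustrative name
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # stand-in values for the real precomputed sinusoidal table
        weights = torch.zeros(num_positions, embedding_dim)
        # persistent=False keeps the buffer out of state_dict(), so checkpoints
        # neither save it nor warn that it was "not initialized" at load time
        self.register_buffer("weights", weights, persistent=False)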

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@younesbelkada @ArthurZucker

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


@younesbelkada (Contributor) left a comment

Hi @MattYoon,
Thanks a lot for the PR! A test that checks conversion between the TensorFlow and PyTorch models is failing. I think we can fix it by adding a _keys_to_ignore_on_load_missing in the TF XGLM modeling file as follows:

_keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"]
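
For context, this attribute is a list of regex patterns declared on the model class; checkpoint keys matching any of them are skipped when from_pretrained reports missing weights. A hypothetical minimal illustration (not the actual TF XGLM file, and the pattern shown is an assumption):

class TFXGLMStyleModel:  # hypothetical class for illustration
    # keys matching these regexes may legitimately be absent from the
    # checkpoint, so they are excluded from the missing-weights warning
    _keys_to_ignore_on_load_missing = [r"embed_positions.weights"]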
Let us know if you have any questions! 🙏 cc also @Rocketknight1: how do we properly deal with non-persistent tensors in TF?

@Rocketknight1 (Member) commented Aug 28, 2023

@younesbelkada Good question! There is no register_buffer method in TF. There are two replacements you can use for it, depending on what the variable is actually doing:

  1. If the variable needs to be saved/loaded with the model, but you just don't want the optimizer to train it, create it with self.add_weight(trainable=False). This is similar to self.register_buffer(persistent=True).
  2. If the weights don't need to be saved and exist only as a performance optimization to avoid recomputing them on every iteration, this is similar to self.register_buffer(persistent=False). In that case, create them in the layer's __init__() or build() method as a tf.constant. They will only be computed once, and marking them as tf.constant lets the compiler apply constant optimizations in the graph. TF then won't treat them as a 'weight' at all, so you'll probably have to add them to _keys_to_ignore_on_load_unexpected if they exist as a weight in the PyTorch model (see the sketch below).
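
A minimal sketch contrasting the two options (layer and attribute names are illustrative, not the actual TF XGLM code):

import numpy as np
import tensorflow as tf

class BufferDemo(tf.keras.layers.Layer):  # illustrative layer
    def build(self, input_shape):
        # Option 1: saved/loaded with the model but never trained,
        # analogous to register_buffer(..., persistent=True)
        self.saved_buffer = self.add_weight(
            name="saved_buffer", shape=(8,), initializer="zeros", trainable=False
        )
        # Option 2: a one-time precomputed constant, analogous to
        # register_buffer(..., persistent=False); TF does not treat it as a
        # weight, so it is never saved, loaded, or trained
        self.cached_table = tf.constant(np.arange(8, dtype=np.float32))
        super().build(input_shape)

    def call(self, inputs):
        # e.g. layer(tf.zeros((2, 8))) broadcasts both tensors over the batch
        return inputs + self.saved_buffer + self.cached_table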

Let me know if you need my help writing a PR for any of this!

@MattYoon (Contributor, Author)

Hi @younesbelkada, thanks for guiding me through the PR! To be frank, I'm not familiar enough with either TF or Transformers to complete this PR. I'm worried that my attempting to fix this issue will cause other problems and that merging this PR will take far longer than necessary.

The issue seems like a very simple fix for someone familiar with the internals of Transformers. Can you or someone else close this PR and take the torch? Sorry I couldn't be much help.

@Rocketknight1 (Member) commented Aug 29, 2023

@MattYoon You don't need to close it! If you allow edits from maintainers, I can push the relevant change to your branch. It should only be one line in the TF code. Are you okay with me doing that?

@MattYoon (Contributor, Author)

Yes that sounds great! I believe "allow edits from maintainers" is active for this PR.

@Rocketknight1 (Member)

@MattYoon Done! I also fixed some spelling in the TF module while I was there.

cc @younesbelkada too

@younesbelkada (Contributor) left a comment

Thanks for your great contribution @MattYoon, and thanks a lot for the help @Rocketknight1!

@Rocketknight1 (Member)

cc @amyeroberts for core maintainer review!

@amyeroberts (Collaborator) left a comment

Thanks for fixing!

@amyeroberts merged commit 2ee60b7 into huggingface:main on Aug 29, 2023
21 checks passed
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
…e#25798)

* fix warning triggering for xglm.embed_positions

* Make TF variable a tf.constant to match (and fix some spelling)

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
…e#25798)

* fix warning triggering for xglm.embed_positions

* Make TF variable a tf.constant to match (and fix some spelling)

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
…e#25798)

* fix warning triggering for xglm.embed_positions

* Make TF variable a tf.constant to match (and fix some spelling)

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Successfully merging this pull request may close these issues:

Loading XGLM triggers warning "Some weights of .. were not initialized from" for a module with no params