
Correct loading of models with shared tensors when using accelerator.load_state() #2875

Merged: 5 commits merged into huggingface:main on Jul 15, 2024

Conversation

jkuntzer (Contributor)

What does this PR do?

I ran into problems with PyTorch's load_state_dict complaining about missing keys. These keys belonged to shared tensors, which the safetensors library intentionally omits when saving. To load such a model correctly, one has to use safetensors' load_model function instead of the default load_state_dict (described here). This was previously not done in the Accelerator's load_state function.

Fixes # (issue)
I think issue #2155 might be relevant, as it also reports problems when loading with accelerator.load_state.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

muellerzr (Collaborator) left a comment:

Thanks! Just one question

Comment on lines -204 to 205:

- models[i].load_state_dict(state_dict, **load_model_func_kwargs)
+ model.load_state_dict(state_dict, **load_model_func_kwargs)
  logger.info("All model weights loaded successfully")
muellerzr (Collaborator):

Any particular reason for this change? I'd expect only the former to be modified.

Member:

In the if statement, he's loading the safetensors model directly, whereas before we were only getting the state dict.

jkuntzer (Contributor, Author):

load_model does both: it loads the file and uses it to populate the model's state dict. Previously, each branch of the if-condition only loaded the file, and after the if-condition the model would load the state dict. Since load_model does both, I indented the statement on line 204 to become part of the else-clause. This becomes clearer when you look at the complete surroundings of the changes instead of only the affected lines.

jkuntzer (Contributor, Author):

The other change in this line (aside from the indent), namely using model instead of models[i], is mostly cosmetic. My linter was complaining that the enumerate call defines model but never uses it.

SunMarc (Member) left a comment:

Thanks for the change and for spotting the issue! Could you add a test with a model with tied weights? You can use the following test for reference: test_save_load_model

jkuntzer (Contributor, Author) commented Jul 5, 2024:

Yes, I'll look into it.

jkuntzer (Contributor, Author) commented Jul 9, 2024:

You can verify that the shared weights are implemented correctly by checking the output. safetensors warns you about that fact.

SunMarc (Member) left a comment:

Thanks for iterating and adding the tests @jkuntzer! Could you do a final check and see whether the test you added fails when you remove the changes you made?

jkuntzer (Contributor, Author) commented Jul 9, 2024:

Just did. This is the expected error message I get when reverting my changes.
[Screenshot: the error message produced after reverting the changes]

Comment on lines 42 to 44:
# need to add this for compliance with other methods
self.weight = self.linear1.weight
self.bias = self.linear1.bias
Member:

Do we really need that? Where does it fail?

jkuntzer (Contributor, Author):

It used to fail earlier in development, but you're right: this part can be safely removed.

SunMarc (Member) commented Jul 9, 2024:

> Just did. This is the expected error message I get when reverting my changes.

I was only expecting linear2.weight and linear2.bias to be missing. Maybe this is due to

self.weight = self.linear1.weight
self.bias = self.linear1.bias

jkuntzer (Contributor, Author) commented Jul 9, 2024:

> > Just did. This is the expected error message I get when reverting my changes.
>
> I was only expecting linear2.weight and linear2.bias to be missing. Maybe this is due to
>
> self.weight = self.linear1.weight
> self.bias = self.linear1.bias

After removing the unnecessary bits, it correctly only throws an error for the weights and bias of the 2nd linear layer.
[Screenshot: the error message now listing only linear2.weight and linear2.bias as missing]

SunMarc (Member) commented Jul 10, 2024:

Nice! Could you just fix the quality issue (make style) and we are good to merge!

@SunMarc SunMarc requested a review from muellerzr July 10, 2024 14:18
muellerzr (Collaborator) left a comment:

Thanks!

muellerzr merged commit f4f1260 into huggingface:main on Jul 15, 2024. 24 of 25 checks passed.