fix load_state_dict for xpu and refine xpu safetensor version check #2879
@SunMarc this PR is ready for review. Could you help review it? Thanks a lot!
LGTM! Thanks for fixing, @faaany!
Nice!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@faaany small merge conflict (the joys of day-of-release-merging-post-OOO), if you can fix that I'll get this in and it'll be part of the release today 🚀
Awesome, thanks! Conflict resolved.
expected_device = (
    torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
)
assert loaded_state_dict[param].device == expected_device
FYI this breaks CUDA tests because on CUDA we end up with cuda:0:0
+1
Please check out the new failure on main @faaany, thanks! :) (The test is failing because of the note above.)
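To illustrate the `cuda:0:0` failure reported above, here is a minimal sketch of how the expected device string could be built without doubling the index. It assumes, as the review comment implies, that `torch_device` may already carry an index (for example `cuda:0`) on some backends. The helper name `expected_device_str` is hypothetical and this is not necessarily the fix that landed in the repository.

```python
from typing import Union


def expected_device_str(torch_device: str, device: Union[int, str]) -> str:
    """Hypothetical helper: build a device string for comparison in a test.

    Assumption: `torch_device` is either a bare backend name ("xpu") or a
    backend name with an index ("cuda:0").
    """
    if isinstance(device, int):
        # Drop any index already present so "cuda:0" plus 0 gives "cuda:0",
        # not the invalid "cuda:0:0".
        backend = torch_device.split(":", 1)[0]
        return f"{backend}:{device}"
    # String devices such as "cpu" or "xpu:1" are already fully qualified.
    return device
```

The result could then be passed to `torch.device(...)` in place of the f-string used in the snippet under review.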
What does this PR do?
This PR fixes:
- `test_load_state_dict` on XPU, because `torch.device(0)` is CUDA by default
- `load_state_dict` when the values of `device_map` are identical integers
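The second case can be pictured with a small sketch: when every entry of `device_map` points to the same integer, that bare index carries no backend, so it has to be qualified explicitly before use. The function name `resolve_single_device` and its `backend` parameter are hypothetical, chosen only for illustration. This is a sketch under those assumptions, not the library's implementation.

```python
from typing import Dict, Optional, Union

Device = Union[int, str]


def resolve_single_device(device_map: Dict[str, Device], backend: str) -> Optional[str]:
    """Hypothetical helper: return one backend-qualified device string when
    every entry of `device_map` maps to the same device, else None."""
    unique = set(device_map.values())
    if len(unique) != 1:
        # Mixed placement: the caller would handle each entry separately.
        return None
    device = unique.pop()
    if isinstance(device, int):
        # A bare integer does not name a backend, so attach one explicitly.
        return f"{backend}:{device}"
    return device
```

For example, `resolve_single_device({"linear1": 0, "linear2": 0}, "xpu")` yields `"xpu:0"`, which avoids relying on whatever backend a bare `0` would default to.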