-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove repetitive entries from device lists #1321
Conversation
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please post a small reproducer of the code that failed before this PR?
For me, the below code runs both timed tests in about 2 seconds with this PR, whereas before it the second test hangs and ends up measuring 1200 seconds:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explainig. I didn't understand your change at first, let's use a set instead of going through the keys of an artificial dict.
402de02
to
c8539a3
Compare
Mmm I don't know what's happening at GitHub right nw, but I can't authorize the tests to run. Will wait for a bit and hopefully the button will come back in a couple of hours. Stay tuned! |
Previously devices() was a list containing duplicate entries. This changes it into a set. This significantly speeds safetensors loading when the device map is long, as the safetensors loop loads each weight entry for each device entry. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
c8539a3
to
0b7c941
Compare
All green, thanks again for your contribution! |
I found when loading that each tensor for a device was reloaded for every layer in an
n**2
manner.It seemed the device list in
load_state_dict
was loaded as device map keys, but then treated as a unique set despite having an element for every entry in the map, i.e. each device was in the list many times over.I searched the file for
devices
, found two uses of lists as if they were sets, and added code to remove duplicate entries from both instances. I used a dict rather than a set so that order of entries is preserved, keeping the first device as the primary one.EDIT: the other list I found already had unique entries and that change has been removed.
With this change, for me, the loading code no longer reloads every tensor for each map entry, instead loading them once for each device.