
remove repetitive entries from device lists #1321

Merged: sgugger merged 1 commit into huggingface:main from xloem's load-each-device-once branch on Apr 17, 2023

Conversation

xloem (Contributor) commented on Apr 17, 2023

I found that, when loading, each tensor for a device was reloaded once per layer, so the work grew in an n**2 manner.

It seems the device list in load_state_dict was built from the device map, but then treated as if it were a unique set, even though it has one element for every entry in the map, i.e. each device appears in the list many times over.

I searched the file for device lists, found two places where lists were used as if they were sets, and added code to remove duplicate entries from both. I used a dict rather than a set so that the order of entries is preserved, keeping the first device as the primary one.

EDIT: the other list I found already had unique entries, and that change has been removed.

With this change, for me, the loading code no longer reloads every tensor for each map entry; it loads each tensor once per device.
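
As an illustration of the problem and of the dict-based de-duplication described above (a minimal sketch with a made-up device_map and a placeholder load_weights_for, not the actual load_state_dict internals):

# Sketch only: the device list is derived from the device map, so every weight
# contributes one entry and each device appears many times over.
device_map = {'layer.0.weight': 0, 'layer.0.bias': 0, 'layer.1.weight': 1}

devices_with_duplicates = list(device_map.values())               # [0, 0, 1]
devices_deduplicated = list(dict.fromkeys(device_map.values()))   # [0, 1], first occurrence kept first

# A loop of the form `for device in devices_with_duplicates: load_weights_for(device)`
# reloads the weights of device 0 once per duplicate entry.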

HuggingFaceDocBuilderDev commented on Apr 17, 2023

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment


Could you please post a small reproducer of the code that failed before this PR?

xloem (Contributor, Author) commented on Apr 17, 2023

For me, the code below runs both timed tests in about 2 seconds with this PR, whereas without it the second test hangs and ends up measuring about 1200 seconds:

import timeit
import accelerate, safetensors.torch, torch

import warnings; warnings.filterwarnings('ignore')

DEV = 3  # device index every weight is mapped to in this reproducer

# a few hundred synthetic tensor names, one per character/digit pair
keys = [f'{letter}{number}' for letter in bytes(list(range(ord('A'), ord('z')))).decode() for number in range(10)]

state_dict = {
        key : torch.rand(1024) for key in keys
}
device_map = {  # every key maps to the same device, so a naive device list has one entry per key
        key : DEV for key in keys
}
safetensors.torch.save_file(state_dict, 'test.safetensors', metadata=dict(format='pt'))
print('Loading state_dict with safetensors.torch.load_file:')
def timed_reference_call():
    # baseline: load the whole checkpoint directly with safetensors
    safetensors.torch.load_file('test.safetensors', device=DEV)
result = timeit.timeit(timed_reference_call, setup = timed_reference_call, number=10)
print(result)
print('Loading state_dict with accelerate.utils.modeling.load_state_dict:')
def timed_test_call():
    # code path under test: accelerate's load_state_dict with a long device_map
    accelerate.utils.modeling.load_state_dict('test.safetensors', device_map)
result = timeit.timeit(timed_test_call, setup = timed_test_call, number=10)
print(result)

sgugger (Collaborator) left a comment


Thanks for explaining. I didn't understand your change at first; let's use a set instead of going through the keys of an artificial dict.

src/accelerate/utils/modeling.py (review thread: outdated, resolved)
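
For reference, a minimal sketch of the set-based approach suggested above (the device_map and the print stand in for the real per-device loading; these are not the exact lines in modeling.py):

device_map = {'layer.0.weight': 0, 'layer.0.bias': 0, 'layer.1.weight': 1}

for device in set(device_map.values()):          # each distinct device visited once
    keys_on_device = [k for k, d in device_map.items() if d == device]
    print(device, keys_on_device)                # stand-in for loading these keys onto `device`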
xloem force-pushed the load-each-device-once branch 2 times, most recently from 402de02 to c8539a3, on April 17, 2023 at 19:20
sgugger (Collaborator) commented on Apr 17, 2023

Mmm, I don't know what's happening at GitHub right now, but I can't authorize the tests to run. I'll wait for a bit and hopefully the button will come back in a couple of hours. Stay tuned!

Previously devices() was a list containing duplicate entries. This
changes it into a set.

This significantly speeds safetensors loading when the device map is
long, as the safetensors loop loads each weight entry for each device
entry.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
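
To see why the duplicated list was costly, here is a hedged imitation of the loading pattern the commit message describes (load_per_device is an illustrative helper, not accelerate's actual implementation):

import safetensors.torch

def load_per_device(checkpoint_path, device_map, devices):
    # With a duplicated `devices` list, the checkpoint is read once per list
    # entry; with a set, it is read once per distinct device.
    tensors = {}
    for device in devices:
        loaded = safetensors.torch.load_file(checkpoint_path, device=device)
        for key, value in loaded.items():
            if device_map.get(key) == device:
                tensors[key] = value
    return tensors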
xloem force-pushed the load-each-device-once branch from c8539a3 to 0b7c941 on April 17, 2023 at 19:51
sgugger merged commit 5e63515 into huggingface:main on Apr 17, 2023
sgugger (Collaborator) commented on Apr 17, 2023

All green, thanks again for your contribution!
