
params partition for skip_init #4722

Merged: 8 commits merged into microsoft:master on Jan 18, 2024

Conversation

inkcherry
Contributor

Some models use skip_init to initialize weights. skip_init first initializes the module on a meta device in its __init__ and then materializes it with to_empty(). This conflicts with DeepSpeed's hook on module.__init__: it is necessary to wait for skip_init to finish before executing _post_init_method. However, since the from ... import skip_init typically occurs outside the zero.Init context, there seems to be no good way to hook into skip_init directly. Hence, the approach here is to delay the execution of _post_init_method to resolve this issue.
Known affected models include HuggingFace models such as chatglm2 and chatglm3.
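For reference, a minimal sketch of the skip_init pattern described above (assuming a recent PyTorch where torch.nn.utils.skip_init is available); it only illustrates the meta-device-then-to_empty() flow, not DeepSpeed's hook:

```python
import torch
from torch.nn.utils import skip_init

# skip_init constructs the module with device="meta" inside __init__ and
# then calls to_empty() to allocate (uninitialized) storage on the target
# device. Any hook attached to module.__init__ therefore fires while the
# parameters are still meta tensors.
linear = skip_init(torch.nn.Linear, 10, 5)

print(linear.weight.device)  # cpu -- storage allocated but left uninitialized
```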

@tjruwase tjruwase removed the request for review from jeffra January 5, 2024 20:57
@tjruwase tjruwase requested review from tohtana and removed request for mrwyattii and loadams January 5, 2024 20:58
@tohtana
Contributor

tohtana commented Jan 8, 2024

Hi @inkcherry,
Just for clarification, it seems that _post_init_method runs after the __init__ of the top-level module. Is this correct?
If so, can we reduce the peak memory footprint of parameters on initialization?

@inkcherry
Contributor Author

inkcherry commented Jan 9, 2024

Hi @inkcherry, Just for clarification, it seems that _post_init_method runs after the __init__ of the top-level module. Is this correct? If so, can we reduce the peak memory footprint of parameters on initialization?

Thanks for the review, @tohtana.
The change defers _post_init_method until the next non-meta module finishes its initialization (the peer module immediately below, or, if there is none, the containing module).
I think this is correct because the order in which modules enter and exit the post-init machinery remains unchanged. However, peak memory usage might be higher, because the module passed to skip_init (which may contain several layers) has to be fully materialized on one rank before it is partitioned. To reduce memory (i.e., partition each child layer as soon as it is initialized on a real device), I think some hook-and-restore around the skip_init lifespan is necessary. I could try to keep that code within a conditional scope. What do you think?
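As a rough illustration of the "delay" idea only (a hypothetical sketch, not the actual DeepSpeed code; install_deferred_post_init and post_init_method are made-up names):

```python
import functools
import torch

def install_deferred_post_init(module_cls, post_init_method):
    """Wrap module_cls.__init__ so the post-init callback is deferred while
    the module's parameters still live on the meta device (e.g. skip_init)."""
    orig_init = module_cls.__init__

    @functools.wraps(orig_init)
    def wrapped_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        if any(p.is_meta for p in self.parameters(recurse=False)):
            # Meta parameters cannot be partitioned yet; remember the
            # callback and run it once a later, non-meta module finishes.
            self._pending_post_init = post_init_method
        else:
            post_init_method(self)

    module_cls.__init__ = wrapped_init
```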

@tohtana
Contributor

tohtana commented Jan 9, 2024

Thank you for the clarification, @inkcherry.

I am wondering what happens if almost all modules are declared with skip_init. In that case, partitioning with zero.Init() won't work, and we would need host memory of size |all parameters| * |number of local GPUs (processes)| on a server, right?

I understand that it is difficult to set a hook in skip_init. But can we set one after skip_init as another approach?
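To make the memory concern above concrete, here is a back-of-the-envelope estimate; the parameter count, dtype, and GPU count are illustrative assumptions (roughly a chatglm2-6B-sized model in fp16 on a node with 8 GPUs):

```python
params = 6.2e9        # assumed parameter count (~chatglm2-6B)
bytes_per_param = 2   # fp16
local_procs = 8       # assumed GPUs (processes) on one server

# Every local process materializes the full parameter set before it can
# be partitioned, so the host needs roughly this much memory at peak:
host_bytes = params * bytes_per_param * local_procs
print(f"~{host_bytes / 2**30:.0f} GiB of host memory")  # ~92 GiB
```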

@inkcherry
Contributor Author

inkcherry commented Jan 12, 2024

@tohtana, thanks for your suggestion!
Yes, if the device is not set to a GPU, the host memory behaves like this.
My concern is that adding logic after skip_init may also lead to high memory usage (before your reminder I had been more focused on functionality). For example, if the module passed to skip_init contains a ModuleList of all transformer blocks, essentially almost all of the model's parameters, we would still need num_processes * all_transformer_blocks_parameters of host memory for the to_empty() call.

Currently I have implemented a hook that avoids this memory issue (after a child module's initialization is completed by _apply(empty_like) inside to_empty, the parameters are partitioned), and it provides better functionality. I decoupled it into a separate function.
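A hypothetical sketch of that hook (illustrative only, not the code in this PR; hook_to_empty and partition_params are made-up names), assuming the torch.nn.Module.to_empty API with a device keyword:

```python
import torch

def hook_to_empty(partition_params):
    """Patch Module.to_empty so partitioning runs as soon as skip_init
    materializes the parameters on a real device."""
    orig_to_empty = torch.nn.Module.to_empty

    def to_empty_and_partition(self, *, device, **kwargs):
        module = orig_to_empty(self, device=device, **kwargs)
        # Parameters now live on a real device; partition them immediately
        # instead of waiting for a later module's __init__ to finish.
        partition_params(module)
        return module

    torch.nn.Module.to_empty = to_empty_and_partition
    return orig_to_empty  # caller restores this after the skip_init call
```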

I tested full-parameter fine-tuning of chatglm2-6B with ZeRO-3, comparing a run without skip_init and without this patch against a run using the model's default skip_init together with this patch. The loss is exactly the same for the first 50 steps.

@inkcherry
Contributor Author

inkcherry commented Jan 16, 2024

@tohtana I just made some changes; could you please take a look? Thanks!

@tohtana
Contributor

tohtana commented Jan 16, 2024

Thank you @inkcherry, the hook you implemented should work. This is intricate and refined work!
It looks good to me to merge once this change passes the tests.

@inkcherry
Contributor Author

inkcherry commented Jan 18, 2024

@tohtana All CI checks have passed. Just a reminder in case you missed it.
Also, thanks to @delock and @guoyejun for their internal help.

@tjruwase tjruwase added this pull request to the merge queue Jan 18, 2024
Merged via the queue into microsoft:master with commit 3110c38 Jan 18, 2024
12 checks passed
@tohtana
Contributor

tohtana commented Jan 18, 2024

@inkcherry This PR was merged. Thank you for your great contribution!

mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
Some models use ```skip_init``` to initialize weights. ```skip_init```
first initializes the module on a meta device in its ```__init__``` and
then materializes it with ```to_empty()```. This conflicts with DeepSpeed's
hook on ```module.__init__```: it is necessary to wait for
```skip_init``` to finish before executing ```_post_init_method```.
However, since the ```from ... import skip_init``` typically occurs
outside the zero.Init context, there seems to be no good way to hook into
```skip_init``` directly. Hence, the approach here is to delay the
execution of ```_post_init_method``` to resolve this issue.
Known affected models include HuggingFace models such as chatglm2 and
chatglm3.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>