Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to new meta device init method #65

Merged
merged 1 commit into from
Mar 28, 2024
Merged

switch to new meta device init method #65

merged 1 commit into from
Mar 28, 2024

Conversation

lchu-ibm
Copy link
Contributor

@lchu-ibm lchu-ibm commented Mar 27, 2024

This PR makes a new and improved version of model init which is more efficient.

Full details: #64

Performance on model init time (model creation + fsdp wrapping)

We have benchmarked the model init performance on 16 aws p4de nodes:

old implementation new implementation
7b 100s 0.2s
13b 193s 0.3s
34b 550s 0.6s
70b 1146s 1.5s

Validation test

Since unittest is hard to capture true-multi-node runs, we have ran external tests to make sure the new implementation pass validation test and yields true init. Specifically, we added the following plugin after FSDP call for testing purpose

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
    model_state = model.state_dict()

if rank == 0:
    model2 = LLaMA(llama_config)
    model2.load_state_dict(model_state)

    model2.check_parameters()

and check_parameters() is defined here

@lchu-ibm lchu-ibm self-assigned this Mar 27, 2024
@nairbv
Copy link
Contributor

nairbv commented Mar 27, 2024

it looks like this is depending on code that's still on a branch? should this PR be a draft?

@lchu-ibm
Copy link
Contributor Author

@nairbv it does not. that extra piece of code is just for a one-time regression/validation test that performed externally and will not be repeated. we might re-run the same test next time if we modify this but this will likely be a good final version.

@lchu-ibm lchu-ibm requested a review from nairbv March 27, 2024 15:14
@lchu-ibm
Copy link
Contributor Author

Although - that function we added as a one-time thing might be potentially useful to add in FMS too @daviswer . And it does not depend on anything else. If we do so, @daviswer just remember to check in your first version (the one I linked above, without the later "wrong fix" like adding .data)

@daviswer
Copy link
Collaborator

@lchu-ibm Sure, just add it in that format to this branch?

@lchu-ibm
Copy link
Contributor Author

@daviswer yes, maybe rename it as validate_parameters()

@lchu-ibm lchu-ibm requested a review from daviswer March 27, 2024 19:13
daviswer added a commit to foundation-model-stack/foundation-model-stack that referenced this pull request Mar 28, 2024
See foundation-model-stack/fms-fsdp#65

Allows us to verify that FSDP-linked model initialization is working as intended in distributed pretraining contexts

This is currently implemented solely for Llama.
@lchu-ibm lchu-ibm merged commit a2d51ac into main Mar 28, 2024
3 checks passed
@lchu-ibm lchu-ibm deleted the meta_init branch March 28, 2024 18:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants