[FSDPv2] Shard on the maximal dim of weights #7134
Conversation
@@ -113,7 +118,8 @@ def __init__(
     for param in module.parameters():
       if torch_xla._XLAC._get_xla_sharding_spec(param) != "":
         continue
-      spmd.mark_sharding(param, mesh, _prepare_spmd_partition_spec(param))
+      spmd.mark_sharding(
+          param, mesh, _prepare_spmd_partition_spec(param, shard_maximal=True))
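For context, here is a minimal sketch of the idea behind shard_maximal (illustrative only; the helper name and the exact logic inside torch_xla's _prepare_spmd_partition_spec may differ): place the "fsdp" mesh axis on the largest dimension of the weight and replicate the rest.

from typing import Optional, Tuple

import torch


def prepare_partition_spec_sketch(
    param: torch.Tensor, shard_maximal: bool = False) -> Tuple[Optional[str], ...]:
  # Start fully replicated: None in every position of the partition spec.
  spec = [None] * len(param.shape)
  if not spec:
    return tuple(spec)  # scalars stay replicated
  # Shard the largest dim when shard_maximal is set, otherwise dim 0
  # (the previous FSDPv2 behavior).
  shard_dim = (max(range(len(param.shape)), key=lambda d: param.shape[d])
               if shard_maximal else 0)
  spec[shard_dim] = "fsdp"
  return tuple(spec)


# An MoE-style weight of shape (8, 4096, 14336): the tiny expert dim 0 is no
# longer the one that gets sharded.
w = torch.empty(8, 4096, 14336, device="meta")
print(prepare_partition_spec_sketch(w, shard_maximal=True))  # (None, None, 'fsdp')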
Should we make it configurable?
No, not necessary... It shouldn't matter to the user...
Ok, I think I get it... We need to do an all-gather anyway before entering the layer, and only one dimension is being sharded.
Yeah, that's right. If the 0th dim is small, say size 8 (MoE), and we are sharding it on a v5p-2048, it will be a disaster.
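To make the point concrete, here is a back-of-the-envelope calculation (the numbers are illustrative, not from the PR): GSPMD pads a sharded dim up to a multiple of the number of shards, so a dim of size 8 split across 1024 devices is padded roughly 128x, while a large dim usually divides the mesh evenly.

def padding_overhead(dim_size: int, num_devices: int) -> float:
  # Each device holds ceil(dim_size / num_devices) slices, so the sharded dim
  # is effectively padded to per_device * num_devices.
  per_device = -(-dim_size // num_devices)  # ceiling division
  return per_device * num_devices / dim_size


print(padding_overhead(8, 1024))      # 128.0 -> sharding the 0th dim of an 8-expert weight
print(padding_overhead(14336, 1024))  # 1.0   -> sharding the maximal dim, no padding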
Force-pushed from 701277f to 9dd39cc.
Thanks Jack for approving the change.
Skip GPU to move fast.
Summary:
This pull request makes FSDPv2 shard on the maximal dim of the weights instead of the 0th dim.
Test Plan:
XLA_USE_SPMD=1 PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
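For reference, a hedged usage sketch of wrapping a module with FSDPv2 on a 1D fsdp mesh (the mesh setup, the toy module, and the exact constructor arguments are assumptions adapted from the torch_xla SPMD examples, not part of this PR); with this change, the larger dim of each weight is the one that ends up sharded:

import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # alternative to setting XLA_USE_SPMD=1 as in the test plan

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "model"))

# Toy layer with a deliberately tiny 0th weight dim (8); with this PR the
# 14336-sized dim is the one sharded on the "fsdp" axis.
model = nn.Linear(14336, 8).to(xm.xla_device())
model = FSDPv2(model, mesh)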