[SPMD] Introduce high level manual sharding APIs #6931
Conversation
```python
                           *,
                           mesh: Mesh = None) -> XLAShardedTensor:
  """
  This API enables manual sharding for the given tensor. Manual sharding disables auto sharding propagation and auto
```
"auto" --> "SPMD", think it's important to not confuse.
LGTM, left a comment on the comment :)
LGTM!
""" | ||
mesh = get_global_mesh() if mesh is None else mesh | ||
t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec) | ||
t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t)) |
Can `t` here be DeviceData?
You mean the input? Yes!
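Pieced together, the hunks quoted in this thread belong to the new `enable_manual_sharding` helper in `xla_sharding.py`. A hedged reconstruction of the whole function is below; the first `partition_spec` annotation and the final `wrap_as_sharded_tensor` call are my guesses, since neither appears in the quoted diff, and the body assumes the module's helpers are already in scope.

```python
# Hedged reconstruction, as it would sit inside
# torch_xla/distributed/spmd/xla_sharding.py where Mesh, XLAShardedTensor,
# mark_sharding, get_global_mesh and unwrap_sharded_tensor are in scope.
def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                           partition_spec: Tuple[Union[Tuple, int, str, None], ...],
                           *,
                           mesh: Mesh = None) -> XLAShardedTensor:
  """Enables manual sharding for `t`; see the docstring quoted below."""
  # Fall back to the global mesh when none is given.
  mesh = get_global_mesh() if mesh is None else mesh
  # Shard the full tensor first, then reshape it down to the per-device
  # shard shape, which is what ops inside the manual region operate on.
  t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
  t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))
  return wrap_as_sharded_tensor(t)  # assumed helper, not shown in the diff
```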
""" | ||
This API enables manual sharding for the given tensor. Manual sharding disables auto sharding proporgation and auto | ||
partition for the given tensor and all subsequential tensors that produced by an op that uses the given tensor as | ||
input, and therefore allows the user to manually call collectives for the tensor and subsequential tensors. It |
Also just curious - how will we enable collectives in a manual region?
XLA cc ops should work by default; just use them as normal. However, we need to teach our cc op wrappers to be aware of SPMD mode, so that will be phase 2 of the manual sharding work.
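To make the intent concrete, here is a minimal sketch of what a collective inside a manual region might look like once the wrappers are SPMD-aware. Per the comment above this does not work yet; `xm.all_reduce` is the existing cc op wrapper, and the `enable_manual_sharding`/`disable_manual_sharding` signatures are my reading of this PR.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

# Enter the manual region: each device now sees its raw shard of x.
x_shard = xs.enable_manual_sharding(x, partition_spec)
# Sharding propagation is off in here, so cross-device communication is
# the user's job, e.g. summing the shards with an all-reduce (phase 2).
x_shard = xm.all_reduce(xm.REDUCE_SUM, x_shard)
# Leave the manual region, handing the shard back to SPMD at full shape.
x_full = xs.disable_manual_sharding(x_shard, partition_spec, x.shape)
```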
Summary:
This pull request introduces:
1. enable_manual_sharding: starts a manual sharding region.
2. disable_manual_sharding: ends the manual sharding region.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_api_e2e
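For reference, a self-contained sketch of the end-to-end flow that test exercises. The new API signatures (notably `disable_manual_sharding` taking the full shape as a third argument) are assumptions based on this PR, not settled documentation; the surrounding `Mesh`/runtime calls are existing torch_xla SPMD APIs.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # turn on SPMD execution mode

# Build a 1 x N device mesh and register it as the global mesh.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))
xs.set_global_mesh(mesh)

x = torch.zeros(8, 8).to(xm.xla_device())
partition_spec = (0, 1)

# Manual region: xx has the per-device shard shape, e.g. (8, 8 // N).
xx = xs.enable_manual_sharding(x, partition_spec)
xx = xx + 1  # runs on the local shard; no sharding propagation here
# Back to SPMD, restoring the full (8, 8) logical shape.
xx = xs.disable_manual_sharding(xx, partition_spec, x.shape)
assert xx.shape == torch.Size([8, 8])
```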