
[SPMD] Introduce high level manual sharding APIs #6931

Merged · 5 commits into master · Apr 17, 2024

Conversation

alanwaketan (Collaborator) commented:

Summary:
This pull request introduces:

  1. enable_manual_sharding: starts a manual sharding region.
  2. disable_manual_sharding: ends the manual sharding region (see the usage sketch below).

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_api_e2e
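
For orientation, here is a minimal usage sketch of the two APIs (not taken from this PR): it assumes SPMD mode is on, that both functions are exposed under torch_xla.distributed.spmd, and that disable_manual_sharding takes a trailing full-shape argument, which is inferred from the discussion rather than quoted from the diff.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # SPMD mode must be active before any sharding call

# 1-D mesh laying out all addressable devices along a single axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,))

t = torch.randn(8, 8).to(xm.xla_device())

# Enter the manual region: t is sharded along dim 0, and SPMD auto
# propagation/partitioning stops for ops that consume it.
t = xs.enable_manual_sharding(t, (0, None), mesh=mesh)

# ... user-managed, per-shard computation (and explicit collectives) ...

# Leave the manual region; the (8, 8) full-shape argument is an assumption
# about disable_manual_sharding's signature, not quoted from this PR.
t = xs.disable_manual_sharding(t, (0, None), (8, 8), mesh=mesh)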

@alanwaketan alanwaketan requested review from yeounoh and jonb377 April 17, 2024 00:46
@alanwaketan alanwaketan self-assigned this Apr 17, 2024
def enable_manual_sharding(t, partition_spec, *,
                           mesh: Mesh = None) -> XLAShardedTensor:
"""
This API enables manual sharding for the given tensor. Manual sharding disables auto sharding propagation and auto
Contributor:

"auto" --> "SPMD", think it's important to not confuse.

@yeounoh (Contributor) left a comment:

LGTM, left a comment for comment :)

@jonb377 (Collaborator) left a comment:

LGTM!

"""
mesh = get_global_mesh() if mesh is None else mesh
t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)  # annotate t with the requested sharding
t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))  # switch t to its per-shard (local) shape
Collaborator:

Can t here be DeviceData?

alanwaketan (Author):

You mean the input? Yes!
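
To make the shape transition concrete: per the exchange above, _spmd_full_to_shard_shape swaps the tensor's global shape for the local shard shape inside the manual region. A small sketch of the expected behavior, reusing the mesh and imports from the earlier sketch; the (2, 8) result assumes a 4-device 1-D mesh and is not taken from this PR's tests.

t = torch.randn(8, 8).to(xm.xla_device())        # global shape: (8, 8)
t = xs.enable_manual_sharding(t, (0, None), mesh=mesh)
# Inside the manual region the wrapped tensor carries the per-shard shape:
# dim 0 is split 4 ways, so torch.Size([2, 8]) is expected here.
print(t.global_tensor.shape)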

"""
This API enables manual sharding for the given tensor. Manual sharding disables auto sharding propagation and auto
partitioning for the given tensor and all subsequent tensors produced by ops that use the given tensor as
input, and therefore allows the user to manually call collectives for the tensor and those subsequent tensors. It
Collaborator:

Also just curious - how will we enable collectives in a manual region?

alanwaketan (Author):

XLA cc ops should work by default; just use them as normal. However, we need to teach our cc-op wrappers to be aware of SPMD mode, so that will be phase 2 of the manual sharding work.
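
A hypothetical sketch of that phase-2 plan (not implemented in this PR): inside a manual region, per-shard compute would interleave with explicit collectives, shown here with the existing xm.all_reduce wrapper standing in for whatever SPMD-aware wrapper lands later. The shapes and the replicated output spec are assumptions, reusing the mesh and imports from the earlier sketch.

t = torch.randn(8, 8).to(xm.xla_device())
t = xs.enable_manual_sharding(t, (0, None), mesh=mesh)    # local shard: (2, 8)

partial = t.global_tensor.sum(dim=0, keepdim=True)        # per-shard partial sum: (1, 8)
total = xm.all_reduce(xm.REDUCE_SUM, partial)             # explicit cross-shard sum

# Replicated result; the (None, None) spec and (1, 8) full shape are assumptions.
out = xs.disable_manual_sharding(total, (None, None), (1, 8), mesh=mesh)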

@alanwaketan alanwaketan merged commit 9b2ac4b into master Apr 17, 2024
4 checks passed
@alanwaketan alanwaketan deleted the alanwaketan/manual_sharding_api branch April 17, 2024 18:28
lausannel pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024
baoleai pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024