Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add manual sharding API for SPMD #2

Merged
merged 4 commits into from
Aug 6, 2024
Merged

Add manual sharding API for SPMD #2

merged 4 commits into from
Aug 6, 2024

Conversation

lausannel
Copy link

@lausannel lausannel commented Aug 6, 2024

Summary:
This pull request makes SPMD support the manual sharding type via a new private API called: _mark_manual_sharding. I don't expect users will need to call this function explicitly.

Besides adding support for the sharding annotation, we also need to define the behavior of the data shards. For data, the current behavior is error out.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test__mark_manual_sharding
Summary:
This pull request supports SPMDFullToShardShape which is a custom op that opens a region for non-partitioned graph in SPMD program. It will stop SPMD auto sharding and partition in that region and therefore allows manual sharding like cc ops.

To implement it, this pull request expands CustomSharding node to accept a new type. To be notice, the output shape of the op needs to be the shard shape of the input, and the node needs to have manual sharding annotation.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_spmd_full_to_shard_shape
Summary:
This pull request enables SPMDShardToFullShape. The trickiest part is how to get the full shape, and here is a couple of options:
1. Bookkeeping the shape full shape that enters SPMDFullToShardShape. This is not selected given the output could be created on the fly.
2. Constructing the full shape from the local shard and the sharding spec. This is not selected given there is no way to deal with the padding. We can't examine the data during the tracing time.
3. Let users pass the full shape in. This is selected because it's just the most sounded path.

Tes Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_e2e -k test_spmd_shard_to_full_shape
Summary:
This pull request introduces:
1. enable_manual_sharding: which starts the manual sharding region.
2. disable_manual_sharding: which disable the manual sharding region.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_api_e2e
@CLAassistant
Copy link

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@baoleai baoleai merged commit 057cec2 into acc Aug 6, 2024
1 of 2 checks passed
@baoleai baoleai deleted the add_manual_sharding branch August 6, 2024 09:46
yitongh pushed a commit that referenced this pull request Aug 8, 2024
Support Disc as backend
Co-authored-by: yancey.yx <yancey.yx@antfin.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
yitongh pushed a commit that referenced this pull request Aug 8, 2024
Support Disc as backend
Co-authored-by: yancey.yx <yancey.yx@antfin.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
anw90 pushed a commit that referenced this pull request Oct 11, 2024
* build with BladeDISC (#8)

* [to #53687860] feat: DISC client header, implement DISCComputation and DISCData

POC implement in : https://code.alibaba-inc.com/torchx/xla/codereview/14984824

Link: https://code.alibaba-inc.com/torchx/xla/codereview/14987956

* Disc computation (#2)

Support Disc as backend
Co-authored-by: yancey.yx <yancey.yx@antfin.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>

* add bazel flag to disable disc backend (#23)

* add flag to disable disc backend in bazel workspace

* support disc debug mode to dump mhlo and logs (#25)

support disc backend debug mode to dump DISC compilation logs

* support flash attention in disc (pytorch#34)

* fix disc flag when complie python (pytorch#39)

* fix bazel flag when complie python

* fix lint.

* support bf16 on disc backend (pytorch#40)

add float-norm pass to support bf16 amp training

* Support Flash Attention 2.5.6 for disc backend (#4)

* fix build failed with NCCL (#5)

* fix build failed on nccl

* using nccl hdrs

* Use the value of DISC_DEVICE as the device type of disc backend (#8)

* change the device type of disc to cuda to make amp work properly

* Use the value of DISC_DEVICE as the device type of disc backend

* disable compilation of DISC by default (#15)

---------

Co-authored-by: Yan Xu <yancey1989@gmail.com>
Co-authored-by: wenting.swt <wenting.swt@alibaba-inc.com>
Co-authored-by: Dalong <yuanxiulong.yxl@alibaba-inc.com>
Co-authored-by: Baole Ai <baole.abl@alibaba-inc.com>
Co-authored-by: Yan Xu <yancey.yx@alibaba-inc.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants