[RFC] XLA Lazy Backend Support In DistributedTensor API #92909
Labels: module: xla (Related to XLA support), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Comments
cc @ronghanghu
pytorchmergebot pushed a commit that referenced this issue on Oct 21, 2023:
This addresses #92909 and enables XLA backend support for the `distribute_tensor` API. Test plan: added a unit test case and tested with CloudTPU. The CI should skip this unless it's an XLA workflow. Pull Request resolved: #110275. Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
pytorchmergebot pushed a commit that referenced this issue on Mar 7, 2024:
In response to the change pytorch/xla#5776 and #92909. Pull Request resolved: #113214. Approved by: https://github.com/wanchaol
pytorchmergebot pushed a commit that referenced this issue on Mar 8, 2024:
Addresses #92909. cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang. Pull Request resolved: #121355. Approved by: https://github.com/wanchaol
🚀 The feature, motivation and pitch
TL;DR
The proposed DistributedTensor provides a new abstraction to express tensor distributions with both sharding and replication parallelism strategies, in eager mode and on non-lazy backends, like `cuda`. We propose to integrate the XLAShardedTensor and mark_sharding APIs for `xla` lazy-backend support in the DistributedTensor API. Our goal is to allow PyTorch users to shard a big tensor across `xla` devices with just a few lines of code, as in the example below. The example follows the one from the DistributedTensor [RFC], with the main difference being the device type `xla`.
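A minimal sketch of the intended usage, assuming the prototype import path `torch.distributed._tensor` used at the time of this RFC (the tensor shape is illustrative):

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# a logical 1-by-4 mesh over four xla devices
mesh = DeviceMesh("xla", [0, 1, 2, 3])

# shard a big tensor along dimension 0 across the four xla devices
big_tensor = torch.randn(100000, 88)
sharded_tensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
```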
Motivation
The proposed DistributedTensor APIs (distribute_tensor, distribute_module) allow the user to express various types of tensor distributions with just a few lines of code. While simple and generic enough to express many common parallelism paradigms, their current backend support does not include lazy backends, like `xla`. PyTorch/XLA offers a set of lower-level XLAShardedTensor APIs that expose sharding annotations for tensors residing on `xla` devices. Both DistributedTensor and XLAShardedTensor support sharding and replication parallelism strategies, defined by a logical device mesh and a sharding placement spec. Here, we propose to integrate the low-level XLAShardedTensor APIs into the high-level DistributedTensor APIs, so that a user can use the same set of DistributedTensor APIs to express tensor distributions with both sharding and replication parallelism strategies.
Pitch
We integrate the `xla` backend-specific XLAShardedTensor APIs into the high-level DistributedTensor APIs (distribute_tensor, distribute_module) so the user can use the same DistributedTensor APIs to express tensor distributions (sharding or replication) on CPU, GPU and `xla` backend devices, like TPU. Some restrictions apply to tensor distributions on the `xla` backend: partial tensor distribution is only available in DistributedTensor native backends, as the strategy is "forbidden from constructor" and only used for intermediary results (tensors); the XLAShardedTensor APIs may propagate sharding and currently assume a fixed device assignment to the logical mesh; and the output tensor(s) are replicated unless explicitly sharded by the user. A call to the high-level DistributedTensor API can easily be translated into a call to the low-level XLAShardedTensor API based on the following conversions:
Conversions
DeviceMesh <> Mesh
DistributedTensor API (e.g., distribute_tensor(...)) works with a device mesh, declared by a DeviceMesh instance. It is a subclass of torch.device, and describes the sharded tensor or module placements. For instance, the following mesh defines a 1-by-4 logical device mesh:
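A sketch of such a declaration, assuming the prototype `torch.distributed._tensor` import path:

```python
from torch.distributed._tensor import DeviceMesh

# device type "xla" and device ids [0, 1, 2, 3]: one row of four devices
device_mesh = DeviceMesh("xla", [0, 1, 2, 3])
```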
The first argument is the device type, "xla", and the mesh is described by the list of logical device IDs (global ranks), [0, 1, 2, 3], which implies a single host (per row) with 4 devices. If the mesh is defined with "xla", the DistributedTensor API can call the XLAShardedTensor API with the same mesh topology, expressed with shape (1, 4):
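For example, assuming the XLAShardedTensor Mesh API lives in `torch_xla.experimental.xla_sharding` (as it did in PyTorch/XLA at the time), a sketch of the equivalent mesh:

```python
import numpy as np
import torch_xla.experimental.xla_sharding as xs

# the same four devices, arranged as a (1, 4) logical mesh
device_ids = np.array([0, 1, 2, 3])
xla_mesh = xs.Mesh(device_ids, (1, 4))
```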
The conversion from DistributedTensor DeviceMesh to XLAShardedTensor Mesh is straightforward:
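A minimal conversion sketch; `convert_to_xla_mesh` is a hypothetical helper, and it assumes DeviceMesh exposes its device assignment through the `mesh` tensor attribute as in the DTensor prototype:

```python
import numpy as np
import torch_xla.experimental.xla_sharding as xs
from torch.distributed._tensor import DeviceMesh

def convert_to_xla_mesh(device_mesh: DeviceMesh) -> xs.Mesh:
    """Hypothetical helper: build an XLAShardedTensor Mesh from a DeviceMesh."""
    # flatten the logical device ids and keep the logical mesh shape
    device_ids = np.asarray(device_mesh.mesh.flatten().tolist())
    mesh_shape = tuple(device_mesh.mesh.shape)
    return xs.Mesh(device_ids, mesh_shape)
```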
We can also define a DeviceMeshBase to capture the common properties and shared interface between DeviceMesh and Mesh:
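A sketch of what such a shared interface could look like; `DeviceMeshBase` and its members are illustrative names proposed here, not an existing PyTorch API:

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

class DeviceMeshBase(ABC):
    """Illustrative common interface for DeviceMesh (DT) and Mesh (XLAShardedTensor)."""

    @property
    @abstractmethod
    def device_type(self) -> str:
        """Backend device type, e.g. 'cuda' or 'xla'."""

    @property
    @abstractmethod
    def shape(self) -> Tuple[int, ...]:
        """Shape of the logical device mesh, e.g. (1, 4)."""

    @abstractmethod
    def get_device_ids(self) -> List[int]:
        """Flattened global device ids participating in the mesh."""
```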
List[Placement] <> Tuple[int, None]
One can convert the DistributedTensor placement specs into the XLAShardedTensor partitioning specs by mapping the "per mesh dimension sharding" (DistributedTensor) onto the "per tensor dimension sharding" (XLAShardedTensor). As an illustration, consider an input tensor of shape (4, 8, 8) and its sharding across a (2, 4) device mesh: the first tensor dimension is sharded 4-way across the second dimension of the device mesh, and the rest is replicated.
In DistributedTensor, this is expressed with a placement spec, [Replicate(), Shard(0)] where each of the spec elements describes how the corresponding mesh dimension will be used, replicated or sharded. Finally, Shard(0) means that the first dimension of the input tensor (index 0) will be sharded, in this case over the second dimension of the mesh.
In XLAShardedTensor, the same sharding strategy is denoted by a partition spec, (1, None, None). Each spec element describes how the corresponding input tensor dimension will be mapped to the device mesh. For example, partition_spec[0] = 1 indicates that the first dimension of the input tensor will be mapped to the second dimension (index 1) of the device mesh, thus split 4-way. None means replication, and the rest of the input dimensions will be replicated.
Note that XLAShardedTensor uses a different sharding spec representation, where the sharding strategy is declared "per tensor dimension". We can transform DT placement specs (Shard or Replicate) into partition specs as sketched below.
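A minimal transformation sketch; `to_partition_spec` is a hypothetical helper, assuming the prototype placement types exported from `torch.distributed._tensor`:

```python
from typing import List, Optional, Tuple
from torch.distributed._tensor import Replicate, Shard

def to_partition_spec(placements, tensor_rank: int) -> Tuple[Optional[int], ...]:
    """Hypothetical helper: map per-mesh-dimension placements (DT) to a
    per-tensor-dimension partition spec (XLAShardedTensor)."""
    # start fully replicated: one None per tensor dimension
    spec: List[Optional[int]] = [None] * tensor_rank
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            # tensor dimension `placement.dim` is split over mesh dimension `mesh_dim`
            spec[placement.dim] = mesh_dim
        elif not isinstance(placement, Replicate):
            raise ValueError(f"unsupported placement for the xla backend: {placement}")
    return tuple(spec)

# the example above: a (4, 8, 8) tensor on a (2, 4) device mesh
print(to_partition_spec([Replicate(), Shard(0)], tensor_rank=3))  # (1, None, None)
```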
DistributedTensor <> XLAShardedTensor
Tensor distributions on the `xla` backend trigger the XLA compiler to partition the tensors and propagate the sharding; the final result is the same as if the computation were not sharded, and the result is replicated across the devices. These are the side effects of the `xla` backend tensor distribution. One can avoid such side effects and simply apply torch ops to the sharded tensors by taking the returned XLAShardedTensor and converting it to a DistributedTensor. This conversion requires that the DistributedTensor resides on the CPU.
DistributedTensor API with `xla` device
distribute_tensor
Calling distribute_tensor with an `xla` device_mesh will trigger a mark_sharding API call with the transformed input arguments, as sketched below. The distribute_tensor API returns a torch.Tensor that can be either a DistributedTensor or an XLAShardedTensor.
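A sketch of this translation, assuming PyTorch/XLA's `mark_sharding` API and the (1, 4) mesh from the conversion examples above:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

device_mesh = DeviceMesh("xla", [0, 1, 2, 3])
big_tensor = torch.randn(4, 8, 8)

# the high-level DistributedTensor call on an xla mesh ...
sharded = distribute_tensor(big_tensor, device_mesh, [Shard(0)])

# ... is roughly equivalent to the low-level XLAShardedTensor calls:
xla_mesh = xs.Mesh(np.array([0, 1, 2, 3]), (1, 4))
xla_tensor = big_tensor.to(xm.xla_device())
xs.mark_sharding(xla_tensor, xla_mesh, (1, None, None))
```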
distribute_module
This API is currently used mainly for manual sharding specification, unlike GSPMD-style automatic sharding propagation; i.e., it allows the user to specify the sharding for selected parameters and treats the rest of the module parameters as replicated, as sketched below. We are currently deciding whether to use this API or a new API for GSPMD-style sharding propagation; we can revisit this during the XLA GSPMD integration once we settle on the API.
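A minimal sketch of this manual-specification style, again assuming the prototype DTensor API; the `partition_fn` below shards only the Linear layers' parameters and leaves everything else replicated:

```python
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh, Shard, distribute_module, distribute_tensor,
)

mesh = DeviceMesh("xla", [0, 1, 2, 3])

def partition_fn(name, module, device_mesh):
    # manually shard Linear parameters; untouched parameters stay replicated
    if isinstance(module, nn.Linear):
        for pname, param in list(module.named_parameters(recurse=False)):
            dist_param = nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(0)])
            )
            module.register_parameter(pname, dist_param)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
sharded_model = distribute_module(model, mesh, partition_fn=partition_fn)
```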
Alternatives
We want to make the DistributedTensor API device agnostic and also support the `xla` lazy backend. PyTorch/XLA provides a set of lower-level APIs which can be integrated into DT to support distributed tensor execution on the lazy backend, with some limitations. The goal is to promote a more consistent user experience across different backends and to use the same abstraction as much as possible. An alternative is to integrate with other distributed tensor abstractions and their APIs, which we may consider after integrating with DT first, if needed.

cc @bdhirsh @wanchaol @JackCaoG @steventk-g @fduwjj @alanwaketan @miladm