🚀 Feature
We propose XLAShardedTensor to represent a sharded tensor that wraps around torch.Tensor, and mark_sharding() API for tensor sharding annotation. XLAShardedTensor allows annotating tensors with sharding specs and dispatching the annotations to the XLA backend for XLA GSPMD support in PyTorch/XLA.
Usage Example
```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_sharding as xs
# import path for the proposed XLAShardedTensor assumed alongside Mesh
from torch_xla.distributed.xla_sharding import Mesh, XLAShardedTensor

# Device mesh: a 2x4 logical mesh over the available XLA devices.
mesh_shape = (2, 4)
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

t = torch.randn(8, 4).to(xm.xla_device())

# Mesh partitioning: each device holds 1/8-th of the input.
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)

# XLAShardedTensor behaves like an unpartitioned native tensor:
# - we can apply native torch.Tensor ops and nn layers
# - it prints the unpartitioned tensor when sent to CPU
# - it provides access to the list of local shards after lazy execution
assert isinstance(m1_sharded, XLAShardedTensor)
```
Motivation
Our goal is to support GSPMD model sharding in PyTorch -- this would allow a user to bring in a PyTorch model implemented (as if) on a single device and annotate a few tensors with the desired sharding specs to run efficient model parallelism. Such an automated approach to model sharding lets the XLA compiler optimize the entire computation graph end-to-end and frees the user from implementing sharded versions of ops with the proper collectives in place.
The current PyTorch ShardedTensor abstraction RFC also provides simple primitives to express sharded tensors, and the ShardedTensor API provides a set of convenience helper functions to shard tensors or model parameters with sharding specs. Under the hood, however, it requires manual/explicit implementation of sharded ops (e.g., a sharded version of torch.nn.functional.linear) and careful injection of collective communications. Moreover, the abstraction represents sharded tensors directly, rather than the sharding annotations needed for XLA compiler-based sharding, which takes place lazily and targets the xla backend.
Pitch
To enable our XLA compiler-based sharding, we propose XLAShardedTensor and the mark_sharding API. In this section, we also describe how users can specify different tensor sharding strategies for the sharding annotation.
XLAShardedTensor
The main use case for XLAShardedTensor is to annotate a native torch.Tensor (on a single device) with a sharding spec. The annotation takes place immediately, but the actual sharding of the tensor happens lazily. Once a tensor is annotated and wrapped inside an XLAShardedTensor, it can be passed to existing PyTorch ops and nn.Module layers as a torch.Tensor. This is critical to ensure that layers and tensor ops can be stacked together as before, which means that the user does not need to rewrite the existing single-device model for sharded computation. Namely, XLAShardedTensor will satisfy the following requirements:
- XLAShardedTensor, as a torch.Tensor subclass, should work directly with native torch ops and nn.Module layers. We use __torch_dispatch__ to send XLAShardedTensor to the XLA backend, and PyTorch/XLA should be able to retrieve the attached sharding annotations, trace the graph with them, and invoke the SPMDPartitioner.
- The handles to the local shards are materialized strictly after the lazy execution.
- The local shards (or replicas) are gathered and materialized to CPU when accessed after lazy execution (a minimal sketch follows this list).
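To make these requirements concrete, below is a minimal, hedged sketch reusing m1_sharded and mesh from the usage example above. The local_shards attribute and its availability semantics follow the proposed definition in the next section (they are not an existing API), and xm.mark_step() is used here simply to trigger the lazy execution.

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm

# The sharded tensor composes with native layers like a plain torch.Tensor.
linear = nn.Linear(4, 2).to(xm.xla_device())
out = linear(m1_sharded)

# Local shard handles are materialized only after lazy execution.
xm.mark_step()
if m1_sharded.local_shards is not None:
  for shard in m1_sharded.local_shards:
    # Each XLAShard pairs a device-local torch.Tensor with its rank.
    print(shard.rank, shard.data.shape)
```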
```python
from dataclasses import dataclass
from typing import List

import torch
from torch.utils._pytree import tree_map
# `no_dispatch` is an internal PyTorch helper used here to keep the
# re-dispatch from recursing back into __torch_dispatch__.
from torch.utils._mode_utils import no_dispatch


@dataclass
class XLAShard:
  data: torch.Tensor
  rank: int


class XLAShardedTensor(torch.Tensor):
  """
  A wrapper around `torch.Tensor` with sharding annotation
  for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
  for IR tracing and converted to HLO graph with sharding annotations;
  XLA SPMDPartitioner takes a pass, propagating and injecting collectives
  to the graph before compilation.
  """

  # XLAShardedTensor behaves like an unpartitioned, combined tensor on the
  # host machine. When the user annotates, this is simply set to the input
  # tensor. When an XLA partitioned output tensor (or a sharding-propagated
  # intermediate tensor) returns as XLAShardedTensor, the backend gathers
  # global data across devices, materializes it, and sets `global_tensor`
  # on the host; the actual device data still remains on the individual
  # devices as sharded or replicated.
  # Note: we should drop this reference, and force all-gather on each access.
  global_tensor: torch.Tensor
  # Shards on the devices are materialized/available after the lazy
  # execution of the SPMDPartitioned HLO graph; otherwise,
  # local_shards is set to `None`. Each XLAShard points to
  # torch.Tensor (xla::device_data).
  # Note: we can consider returning a callback or even defining
  # sharding at XLAShardedTensor construction after the PjRt migration.
  local_shards: List[XLAShard] = None

  __slots__ = ['global_tensor']

  @staticmethod
  def __new__(cls, elem: torch.Tensor, *args, **kwargs):
    r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        elem.size(),
        strides=elem.stride(),
        storage_offset=elem.storage_offset(),
        dtype=elem.dtype,
        layout=elem.layout,
        device=elem.device,
        requires_grad=kwargs.get("requires_grad", False))
    r.global_tensor = elem.detach() if r.requires_grad else elem
    return r

  @property
  def sharding_spec(self):
    return NotImplemented

  @property
  def shards(self):
    return NotImplemented

  def __repr__(self):
    return f"XLAShardedTensor({self.global_tensor})"

  @classmethod
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """
    The dispatcher allows the unwrapped torch.Tensor to be re-dispatched to
    the `xla` backend as XlaTensor, and the XlaTensor with an associated
    sharding spec to be received and wrapped as XLAShardedTensor.
    """

    def unwrap(elem):
      return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem

    def wrap(elem):
      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem

    # no_dispatch is only needed if you use enable_python_mode.
    # It prevents infinite recursion.
    with no_dispatch():
      # re-dispatch to C++
      rs = tree_map(wrap,
                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
    return rs
```
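As a small illustration of the intended wrapper semantics, here is a hedged sketch that exercises the class definition above on plain CPU tensors only; no XLA device or actual sharding is involved.

```python
import torch

# An op on an XLAShardedTensor is unwrapped to its `global_tensor`,
# dispatched as a regular aten op, and the torch.Tensor result is
# re-wrapped as an XLAShardedTensor by __torch_dispatch__.
a = XLAShardedTensor(torch.randn(4, 4))
b = torch.randn(4, 4)
c = a + b
assert isinstance(c, XLAShardedTensor)
print(c)  # XLAShardedTensor(tensor([...]))
```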
mark_sharding API
Users can annotate native PyTorch tensors using the mark_sharding API, which takes a torch.Tensor as input and returns an XLAShardedTensor as output.
```python
@requires_pjrt
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
                  partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
  """
  Annotates the tensor provided with an XLA partition spec. Internally,
  it annotates the corresponding XLATensor as sharded for the XLA
  SpmdPartitioner pass.

  Args:
    t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be
      annotated with partition_spec.

    mesh (Mesh): describes the logical XLA device topology and the
      underlying device IDs.

    partition_spec (Tuple[int, None]): A tuple of device_mesh dimension
      indices or `None`. This specifies how each input rank is sharded
      (index into mesh_shape) or replicated (None). For example, we can
      shard an 8x10 tensor 4-way row-wise and replicate column-wise.
      >> input = torch.randn(8, 10)
      >> mesh_shape = (4, 2)
      >> partition_spec = (0, None)

  Examples
  --------
  mesh_shape = (4, 2)
  num_devices = len(xm.get_xla_supported_devices())
  device_ids = np.array(range(num_devices))
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

  # 4-way data parallel
  input = torch.randn(8, 32).to(xm.xla_device())
  xs.mark_sharding(input, mesh, (0, None))

  # 2-way model parallel
  linear = nn.Linear(32, 10).to(xm.xla_device())
  xs.mark_sharding(linear.weight, mesh, (None, 1))
  """
  return NotImplemented
```
Mesh
The mark_sharding API takes in a logical device mesh, mesh: Mesh.
```python
class Mesh:
  """Describes the logical XLA device topology mesh and the underlying resources.

  Args:
    device_ids (Union[np.ndarray, List]): A raveled list of device IDs in a
      custom order. The list is reshaped to a `mesh_shape` array, filling the
      elements using C-like index order.

    mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology
      shape of the device mesh, and each element describes the number of
      devices in the corresponding axis.

    axis_names (Tuple[str, ...]): A sequence of resource axis names to be
      assigned to the dimensions of the `devices` argument. Its length should
      match the rank of `devices`.

  Example:
  --------
  mesh_shape = (4, 2)
  num_devices = len(xm.get_xla_supported_devices())
  device_ids = np.array(range(num_devices))
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
  mesh.get_logical_mesh()
  >> array([[0, 1],
            [2, 3],
            [4, 5],
            [6, 7]])
  mesh.shape()
  >> OrderedDict([('x', 4), ('y', 2)])
  """
```
Sharding Specification
The mark_sharding API takes mesh and partition_spec as input to annotate a tensor with different sharding specifications, such as replicated, tiled, or partially tiled:
- mesh (Mesh): A device Mesh instance describing the logical topology of the device mesh; each mesh dimension describes the number of devices in the corresponding axis.
- partition_spec (Tuple[int, None]): A tuple of device_mesh dimension indices or None. This specifies how each input rank is sharded (index into mesh_shape) or replicated (None).
partition_spec has the same rank as the input tensor, and each dimension describes how the corresponding input tensor dimension is sharded across the device mesh (logically defined by mesh_shape). For example, an 8x32 input tensor t can be annotated as partially tiled over a 4x2 device mesh as follows:
partition_spec = (1, None) means that the first dimension (0-th index) of the input is sharded across the two device columns (mesh_shape[partition_spec[0]] = 2) and the second dimension is replicated, as specified by partition_spec[1] = None. Similarly, one can fully replicate with partition_spec = (None, None) or fully shard with partition_spec = (0, 1) across the devices.
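For concreteness, here is a hedged sketch of these three flavors for the 8x32 tensor over the 4x2 mesh described above (assuming 8 XLA devices are available; in practice a tensor would carry a single annotation, so the alternatives are commented out):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_sharding as xs
from torch_xla.distributed.xla_sharding import Mesh

mesh = Mesh(np.array(range(8)), (4, 2), ('x', 'y'))
t = torch.randn(8, 32).to(xm.xla_device())

# Partially tiled: rows split across the 2 device columns, columns replicated.
xs.mark_sharding(t, mesh, (1, None))

# Fully replicated on every device:
#   xs.mark_sharding(t, mesh, (None, None))
# Fully tiled over the 4x2 mesh (each device holds a 2x16 shard):
#   xs.mark_sharding(t, mesh, (0, 1))
```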
Alternatives
PyTorch currently only supports manual sharding APIs and primitives, like the ShardedTensor abstraction RFC. This is great for more advanced users who would implement and run custom sharding strategies. The XLAShardedTensor sharding API focuses on bringing automated, XLA compiler-based sharding to PyTorch users.
Additional context
We also have a separate RFC for a high-level GSPMD API that will wrap an existing PyTorch module and apply sharding specs to select tensors. The high-level API will use XLAShardedTensor and mark_sharding as building blocks and make the sharding annotation experience seamless and easy for the user.
cc @ronghanghu @JackCaoG @miladm @pritamdamania87 @wanchaol @fduwjj @mrshenli