
[SPMD] Hybrid Device mesh creation #5147

Merged 14 commits on Jun 19, 2023
128 changes: 126 additions & 2 deletions torch_xla/experimental/xla_sharding.py
@@ -1,5 +1,5 @@
import os
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
import torch
import torch_xla
@@ -8,7 +8,8 @@
import torch_xla.runtime as xr

import numpy as np
from typing import Tuple, Union, List
import itertools
from typing import Tuple, Union, List, Sequence, Any, Optional
from enum import IntEnum


@@ -70,6 +71,129 @@ def get_logical_mesh(self):
    return self.device_ids.reshape(self.mesh_shape)


class HybridMesh(Mesh):
  device_ids: np.ndarray
  ici_mesh_shape: Tuple[int, ...]
  dcn_mesh_shape: Tuple[int, ...]
  axis_names: Tuple[str, ...]

  def __init__(self,
               device_ids: Union[np.ndarray, List],
               ici_mesh_shape: Tuple[int, ...],
               dcn_mesh_shape: Tuple[int, ...] = None,
               axis_names: Tuple[str, ...] = None):
    if dcn_mesh_shape is None:
      dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
    mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
    assert len(ici_mesh_shape) == len(dcn_mesh_shape)
    super().__init__(device_ids, mesh_shape, axis_names)
    self.device_attributes = pjrt.global_device_attributes()
    if 'slice_index' in self.device_attributes[0] and np.prod(
        dcn_mesh_shape) == 1:
      raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
    self.device_ids = device_ids
    self.ici_mesh_shape = ici_mesh_shape
    self.dcn_mesh_shape = dcn_mesh_shape
    self.axis_names = axis_names

  def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
    r"""Rearrange TPU devices in a slice into a physical mesh."""
Collaborator:
Can you add:

1. A note that this is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
2. The following description of the function:

   r"""Rearrange TPU devices in a slice into a physical mesh.

   Args:
     devices: A list of device logical ordinals in a TPU slice.

   Returns:
     A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
       v2 and v3, global_z is instead cores_per_chip (i.e., 2).
   """

    device_coords = [self.device_attributes[d]['coords'] for d in devices]
    dims = tuple(d + 1 for d in max(device_coords))
    out = np.empty(dims, dtype=object)
    for coords, d in zip(device_coords, devices):
      out[coords[0], coords[1], coords[2]] = d
    return out

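For intuition, here is a minimal standalone sketch (not part of the PR) of what this rearrangement produces. The coords values are hypothetical stand-ins for what pjrt.global_device_attributes() would report on a 2x2x1 slice.

import numpy as np

# Hypothetical device attributes for a 2x2x1 slice; coords are made up.
device_attributes = [
    {'coords': [0, 0, 0]},
    {'coords': [1, 0, 0]},
    {'coords': [0, 1, 0]},
    {'coords': [1, 1, 0]},
]
devices = [0, 1, 2, 3]

device_coords = [device_attributes[d]['coords'] for d in devices]
dims = tuple(d + 1 for d in max(device_coords))  # (2, 2, 1)
out = np.empty(dims, dtype=object)
for coords, d in zip(device_coords, devices):
  out[coords[0], coords[1], coords[2]] = d
# out[x, y, z] now holds the device ordinal at physical position (x, y, z):
# [[[0], [2]],
#  [[1], [3]]]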
  def _create_device_mesh_for_nd_torus(
Collaborator:
Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of the mesh_shape to it?

Collaborator:
Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D planes have a faster ICI than the links that connect across them. Therefore, we should map the most speed-demanding rank, i.e., the highest rank of the mesh, onto the inner 2D planes.

Collaborator:
Now that I have read more of the code, this algorithm seems quite restrictive:

  1. It only works for mapping a 2D or 3D logical mesh onto the 3D physical mesh.
  2. For a 3D logical mesh, I think the logical mesh needs to be a transpose of the physical mesh.
  3. For a 2D logical mesh, it just tries to map a combination of the physical axes onto each dimension of the logical mesh.

After these simple rules, it then makes sure that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] will always be a 2D slice of the 3D physical mesh.

If my understanding is correct, @khatwanimohit can you polish my comments and turn them into the comment of this helper?

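To make the rule above concrete, here is a minimal standalone sketch (not the PR's code path) of the greedy search in _create_device_mesh_for_nd_torus: for each logical axis, starting from the most network-intensive one, find a combination of remaining physical axes whose sizes multiply to the logical axis size. The 4x4x4 physical shape and (4, 16) logical mesh are hypothetical.

import itertools
import numpy as np

physical_shape = [4, 4, 4]  # hypothetical 4x4x4 physical torus
mesh_shape = (4, 16)        # logical mesh, lowest network intensity first

assignable = list(physical_shape)
assignment = [() for _ in mesh_shape]
# Walk logical axes from highest network intensity (last) to lowest (first).
for logical_axis, size in reversed(list(enumerate(mesh_shape))):
  for num_axes in range(3, 0, -1):
    for c_indices in itertools.combinations(range(len(assignable)), num_axes):
      if np.prod([assignable[i] for i in c_indices]) == size:
        assignment[logical_axis] = c_indices
        # Zero out the physical axes that are now taken.
        assignable = [0 if i in c_indices else v for i, v in enumerate(assignable)]
        break
    if assignment[logical_axis]:
      break

print(assignment)  # [(2,), (0, 1)]: the size-16 axis takes physical axes 0 and 1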
Collaborator:
You can add:

This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.

      self, physical_mesh: np.ndarray,
      mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    # Remaining physical axes to be assigned to logical axes.
    assignable_physical_mesh = list(physical_mesh.shape)
    # Map each logical axis to a subset of physical axes.
    assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
    # Assign logical axes from highest network intensity to lowest.
    # `mesh_shape` is assumed to be ordered by lowest network intensity first,
    # so reverse it first.
    for logical_axis_index, logical_axis_size in reversed(
        list(enumerate(mesh_shape))):
      for num_axes in range(3, 0, -1):
        axes = itertools.combinations(assignable_physical_mesh, num_axes)
        indices = itertools.combinations(
            range(len(assignable_physical_mesh)), num_axes)
        for c_axes, c_indices in zip(axes, indices):
          if np.product(c_axes) == logical_axis_size:
            assignment[logical_axis_index] = c_indices
            # Zero the assigned physical axes.
            assignable_physical_mesh = [
                0 if i in c_indices else v
                for i, v in enumerate(assignable_physical_mesh)
            ]
            break
        if assignment[logical_axis_index]:
          # We already found an assignment from one candidate above.
          break
      else:
        # If the num_axes for loop did not break, i.e. none of the candidates
        # work, we end up here via this for-else construct.
        if logical_axis_size > 1:
          raise NotImplementedError(
              'Failed to find assignment for logical_axis_index'
              f' {logical_axis_index} of size {logical_axis_size} with remaining'
              f' assignable mesh {assignable_physical_mesh}. The size of each'
              ' axis in your logical mesh must be equal to the product of'
              ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
              ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
          )
    # Flatten the assignment
    transpose: List[int] = []
    for x in assignment:
      for y in x:
        transpose.append(int(y))
    return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment

  def _create_device_mesh(self, mesh_shape: Sequence[int],
                          devices: Sequence[Any]) -> np.ndarray:
    if np.prod(mesh_shape) != len(devices):
      raise ValueError(
          f'Number of devices {len(devices)} must equal the product '
          f'of mesh_shape {mesh_shape}')
    physical_mesh = self._get_physical_tpu_mesh(devices)
    device_mesh, assignment = self._create_device_mesh_for_nd_torus(
        physical_mesh, mesh_shape)
    return device_mesh

  def _create_hybrid_device_mesh(self, mesh_shape: Sequence[int],
                                 dcn_mesh_shape: Sequence[int],
                                 devices: Sequence[Any]) -> np.ndarray:
    granule_dict = defaultdict(list)
    slice_index_attr = [d['slice_index'] for d in self.device_attributes]
    for d, dev in enumerate(self.device_attributes):
      granule_dict[dev['slice_index']].append(d)
    granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
    if np.prod(dcn_mesh_shape) != len(granules):
      raise ValueError(
          f'Number of slices {len(granules)} must equal the product of '
          f'dcn_mesh_shape {dcn_mesh_shape}')
    per_granule_meshes = [
        self._create_device_mesh(mesh_shape, granule) for granule in granules
    ]
    granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
    blocks = np.vectorize(
        lambda i: per_granule_meshes[i], otypes=[object])(
            granule_mesh)
    device_mesh = np.block(blocks.tolist())
    return device_mesh

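The np.vectorize/np.block composition above can be hard to read; here is a minimal standalone sketch (with made-up per-slice meshes) of how two slices, each already arranged as a (1, 4) mesh, are stacked along the DCN axis.

import numpy as np

per_granule_meshes = [np.array([[0, 1, 2, 3]]), np.array([[4, 5, 6, 7]])]
dcn_mesh_shape = (2, 1)

granule_mesh = np.arange(2).reshape(dcn_mesh_shape)  # [[0], [1]]
blocks = np.vectorize(
    lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
device_mesh = np.block(blocks.tolist())
print(device_mesh)
# [[0 1 2 3]
#  [4 5 6 7]]  -> shape (2, 4): dcn_mesh_shape * per-slice mesh shape, elementwise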
  def get_logical_mesh(self):
    if np.prod(self.dcn_mesh_shape) > 1:  # multislice
      return self._create_hybrid_device_mesh(self.ici_mesh_shape,
                                             self.dcn_mesh_shape,
                                             self.device_ids)
    # single slice
    return self._create_device_mesh(self.ici_mesh_shape, self.device_ids)

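A hedged usage sketch (not from the PR): constructing a HybridMesh for a hypothetical multislice setup with 2 slices of 4 devices each. The device count, mesh shapes, and axis names are illustrative, and the constructor queries pjrt.global_device_attributes(), so this only runs in a real multislice TPU environment.

import numpy as np

num_devices = 8
device_ids = np.arange(num_devices)

# ici_mesh_shape shards within a slice, dcn_mesh_shape across slices; the
# overall logical mesh shape is their elementwise product, here (2, 4).
mesh = HybridMesh(
    device_ids,
    ici_mesh_shape=(1, 4),
    dcn_mesh_shape=(2, 1),
    axis_names=('data', 'model'))

logical = mesh.get_logical_mesh()  # np.ndarray of device ids with shape (2, 4)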

class ShardingType(IntEnum):
  # ShardingType enum ID maps to OpSharding.Type if applicable.
  REPLICATED = 0