
[SPMD] Hybrid Device mesh creation #5147

Merged: 14 commits merged into master on Jun 19, 2023

Conversation

khatwanimohit (Collaborator)

No description provided.

@khatwanimohit requested review from yeounoh and jonb377 on June 8, 2023 at 20:33
@khatwanimohit (Collaborator, Author)

cc @alanwaketan

@jonb377 (Collaborator) left a comment

Looking great Mohit! Could we also add some basic unit tests in https://github.com/pytorch/xla/blob/master/test/spmd/test_xla_sharding.py?

@alanwaketan (Collaborator) left a comment

It would be easier for me to read your code if you briefly described what you did in the PR, especially the key complex functionalities such as _create_hybrid_device_mesh.

Or you can leave a comment on the code.

      out[coords[0], coords[1], coords[2]] = d
    return out

  def _create_device_mesh_for_nd_torus(
Collaborator:

Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of the mesh_shape to it?

Collaborator:

Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D slices have faster ICI than the links that connect across them. Therefore, we should map the most communication-demanding rank, i.e., the highest rank of the mesh, to the inner 2D slices.

Collaborator:

Now that I've read more of the code, this algorithm seems quite restricted:

  1. It only works for mapping a 2D or 3D logical mesh onto the 3D physical mesh.
  2. For a 3D mesh, I think the logical mesh needs to be a transpose of the physical mesh.
  3. For a 2D mesh, it tries to map a combination of the physical axes onto each dimension of the logical mesh.

With these simple rules, it then ensures that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] are always a 2D slice of the 3D physical mesh.

If my understanding is correct, @khatwanimohit, can you polish my comments and turn them into the comment of this helper?
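A purely illustrative numpy sketch of the mapping described above (not the helper's actual implementation): with a 2x2x2 physical mesh and a (2, 4) logical mesh, each row of the logical mesh is a contiguous 2D slice of the physical mesh.

```python
import numpy as np

# Illustrative sketch only (not the helper's code): a 2x2x2 physical mesh
# flattened into a (2, 4) logical mesh so that logical axis 1, the most
# communication-heavy axis, spans the inner (assumed faster-ICI) physical axes.
physical_mesh = np.arange(8).reshape(2, 2, 2)   # [x, y, z] device ordinals
logical_mesh = physical_mesh.reshape(2, 4)      # axis 1 merges the y and z axes
print(logical_mesh)
# [[0 1 2 3]
#  [4 5 6 7]]
# logical_mesh[0] is exactly physical_mesh[0], a 2D slice of the torus, so
# physically adjacent devices stay adjacent in the logical mesh.
```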

Collaborator:

You can add:

This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.

    hybrid_mesh = xs.HybridMesh(
        ici_mesh_shape=(1, 4), dcn_mesh_shape=(num_slices, 1))
    print(hybrid_mesh.get_logical_mesh())
    self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
Collaborator:

Does this result respect the _create_device_mesh_for_nd_torus algorithm?

Collaborator (Author):

Yes, I have confirmed this against JAX's mesh.

Collaborator:

Can you make ici_mesh_shape=(2, 2)? I think that would better show how the algorithm works.

Collaborator (Author):

Changed ici_mesh_shape

@alanwaketan (Collaborator) left a comment

I just noticed that most of the helpers you introduced, @khatwanimohit, are inspired by https://github.com/google/jax/blob/bfe8acb31e04a540daad3f568239ec0e5c3f0d0f/jax/experimental/mesh_utils.py. In fact, all of those helpers have a very nice docstring explaining what they do.

I recommend that the next time you import JAX utils into PyTorch/XLA, you:

  1. List the source on each util you imported.
  2. Import their docstrings as well. Those are really critical for the readability of the code.

Also, have you checked the licenses to make sure that you can copy code from JAX into PyTorch/XLA? If not, I can do the research for you.

@alanwaketan (Collaborator) left a comment

Mostly looking good to me. Thanks, @khatwanimohit.

Please address the comments on readability.


  def get_logical_mesh(self):
    return self.device_ids.reshape(self.mesh_shape)


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
Collaborator:

Can you make this a per-helper comment, for each helper that you imported?

      out[coords[0], coords[1], coords[2]] = d
    return out

  def _create_device_mesh_for_nd_torus(
Collaborator:

You can add:

This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.

    super().__init__(device_ids, mesh_shape, axis_names)

  def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
    r"""Rearrange TPU devices in a slice into a physical mesh."""
Collaborator:

Can you add:

  1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
  2. The following description of the function:

  r"""Rearrange TPU devices in a slice into a physical mesh.

  Args:
    devices: A list of device logical ordinals in a TPU slice.

  Returns:
    A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
      v2 and v3, global_z is instead cores_per_chip (i.e., 2).
  """

        physical_mesh, mesh_shape)
    return device_mesh

  def _create_hybrid_device_mesh(self, ici_mesh_shape: Sequence[int],
Collaborator:

Can you add:

  1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
  2. And the following function description:

  """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

  Args:
    ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
      by increasing network intensity, e.g. [replica, data, mdl] where mdl has
      the most network communication requirements.
    dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
      in the same order as mesh_shape.

  Returns:
    A np.ndarray of device logical ordinals with ici_mesh_shape * dcn_mesh_shape as its shape
    that can be fed into HybridMesh for hybrid parallelism.
  """

    return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment

  # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L231
  def _create_device_mesh(self,
Collaborator:

I didn't mention this one since your logic is quite different. I suggest you undo it.

Collaborator (Author):

Fixed the comment

@alanwaketan (Collaborator) left a comment

LGTM. Thanks, Mohit.

@khatwanimohit merged commit 60a6d60 into master on Jun 19, 2023
@will-cromar (Collaborator)

The TPU CI broke after this PR merged. Is this related?

Step #4 - "run_e2e_tests": ======================================================================
Step #4 - "run_e2e_tests": ERROR: test_hybrid_mesh_shape (__main__.BasicShardingTest)
Step #4 - "run_e2e_tests": ----------------------------------------------------------------------
Step #4 - "run_e2e_tests": Traceback (most recent call last):
Step #4 - "run_e2e_tests":   File "/src/pytorch/xla/test/spmd/test_xla_sharding.py", line 462, in test_hybrid_mesh_shape
Step #4 - "run_e2e_tests":     hybrid_mesh = self._get_hybrid_mesh((1, self.n_devices))
Step #4 - "run_e2e_tests":   File "/src/pytorch/xla/test/spmd/test_xla_sharding_base.py", line 42, in _get_hybrid_mesh
Step #4 - "run_e2e_tests":     return xs.HybridMesh(ici_mesh_shape=ici_mesh_shape)
Step #4 - "run_e2e_tests":   File "/usr/local/lib/python3.8/site-packages/torch_xla/experimental/xla_sharding.py", line 122, in __init__
Step #4 - "run_e2e_tests":     mesh = self._create_device_mesh(self.ici_mesh_shape)
Step #4 - "run_e2e_tests":   File "/usr/local/lib/python3.8/site-packages/torch_xla/experimental/xla_sharding.py", line 257, in _create_device_mesh
Step #4 - "run_e2e_tests":     device_mesh, assignment = self._create_device_mesh_for_nd_torus(
Step #4 - "run_e2e_tests":   File "/usr/local/lib/python3.8/site-packages/torch_xla/experimental/xla_sharding.py", line 220, in _create_device_mesh_for_nd_torus
Step #4 - "run_e2e_tests":     raise NotImplementedError(
Step #4 - "run_e2e_tests": NotImplementedError: Failed to find assignment for logical_axis_index 1 of size 8 with remaining assignable mesh [2, 2, 1]. The size of each axis in your logical mesh must be equal to the product of some subset of the physical mesh axis sizes. E.g logical mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.
Step #4 - "run_e2e_tests": 
Step #4 - "run_e2e_tests": ----------------------------------------------------------------------
Step #4 - "run_e2e_tests": Ran 26 tests in 0.968s
Step #4 - "run_e2e_tests": 
Step #4 - "run_e2e_tests": FAILED (errors=1)
Step #4 - "run_e2e_tests": [[0 1]
Step #4 - "run_e2e_tests":  [2 3]
Step #4 - "run_e2e_tests":  [4 5]
Step #4 - "run_e2e_tests":  [6 7]]
Step #4 - "run_e2e_tests": ++ kubectl get pod/xla-test-job-kl46l -o 'jsonpath={.status.containerStatuses[?(@.name=="xla-test")].state.terminated.exitCode}'

@alanwaketan (Collaborator)

Let's have a follow-up to disable the test for TPU.

You can do that by following: https://github.com/pytorch/xla/blob/master/test/test_zero1.py#L13
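A minimal sketch of such a skip guard, assuming the standard unittest.skipIf pattern referenced above (the exact device-type check used in test_zero1.py may differ):

```python
import unittest

import torch_xla.core.xla_model as xm


class BasicShardingTest(unittest.TestCase):

  # Hypothetical sketch only: skip the failing test on TPU, mirroring the skip
  # pattern referenced above. The actual follow-up change may look different.
  @unittest.skipIf(
      xm.xla_device_hw(xm.xla_device()) == 'TPU',
      'Hybrid mesh creation fails on this TPU topology; see the CI log above')
  def test_hybrid_mesh_shape(self):
    pass  # real test body lives in test/spmd/test_xla_sharding.py
```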

ManfeiBai pushed a commit that referenced this pull request Jun 22, 2023