[SPMD] Hybrid Device mesh creation #5147
Conversation
cc @alanwaketan
Looking great, Mohit! Could we also add some basic unit tests in https://github.com/pytorch/xla/blob/master/test/spmd/test_xla_sharding.py?
It will be easier for me to read your code if you briefly describe what you did in the PR, especially the key complex functionalities like _create_hybrid_device_mesh.
Or you can leave a comment on the code.
      out[coords[0], coords[1], coords[2]] = d
    return out

  def _create_device_mesh_for_nd_torus(
Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of the mesh_shape to it?
Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D slices have faster ICI than the links that connect across them. Therefore, we should map the most communication-demanding rank, i.e., the highest rank of the mesh, onto the inner 2D slices.
Now that I have read more of the code, this algorithm seems quite restricted:
- It only handles mapping a 2D or 3D logical mesh onto the 3D physical mesh.
- For a 3D logical mesh, I think it needs to be a transpose of the physical mesh.
- For a 2D logical mesh, it just tries to map a combination of the physical axes onto each dimension of the logical mesh.
Under these simple rules, it then makes sure that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] will always be a 2D slice of the 3D physical mesh.
If my understanding is correct, @khatwanimohit can you polish my comments and turn them into the docstring of this helper?
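A rough sketch of the kind of assignment described above (hypothetical shapes and a fixed axis folding, not the actual search that _create_device_mesh_for_nd_torus performs):

```python
import numpy as np

# Hypothetical 2x2x4 physical torus of 16 device ordinals.
physical_mesh = np.arange(16).reshape(2, 2, 4)

# Suppose the logical mesh is (data=2, model=8). The last (most
# communication-heavy) logical axis should absorb the inner physical
# axes, which have the fastest ICI, so fold the y- and z-axes into it.
logical_mesh = physical_mesh.reshape(2, 8)

print(logical_mesh)
# Each row logical_mesh[i] is a contiguous 2x4 slice of the physical
# torus, so devices that are close physically stay close logically.
```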
You can add:
This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
Force-pushed from ce4f052 to 9c6d8ab.
    hybrid_mesh = xs.HybridMesh(
        ici_mesh_shape=(1, 4), dcn_mesh_shape=(num_slices, 1))
    print(hybrid_mesh.get_logical_mesh())
    self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
Does this result respect the _create_device_mesh_for_nd_torus algorithm?
Yes, I have confirmed this against JAX's mesh.
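For reference, a rough sketch of one way to cross-check against JAX (assuming jax is installed in the same multi-slice TPU environment; mesh_utils.create_hybrid_device_mesh exists in recent JAX releases, though its exact signature may vary):

```python
# Hypothetical cross-check against JAX's own hybrid mesh builder.
import jax
from jax.experimental import mesh_utils

num_slices = 2  # hypothetical number of TPU slices
devices = jax.devices()

jax_mesh = mesh_utils.create_hybrid_device_mesh(
    (1, 4),           # ICI (within-slice) mesh shape
    (num_slices, 1),  # DCN (across-slice) mesh shape
    devices)

# Compare the device ordinals with xs.HybridMesh(...).get_logical_mesh().
print([[d.id for d in row] for row in jax_mesh])
```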
Can you make the ici_mesh_shape=(2, 2)? I think that would better show how the algorithm works.
Changed ici_mesh_shape
I just noticed that most of the helpers you introduced, @khatwanimohit, are inspired by https://github.com/google/jax/blob/bfe8acb31e04a540daad3f568239ec0e5c3f0d0f/jax/experimental/mesh_utils.py. In fact, all those helpers have very nice docstrings explaining what they do.
I recommend that next time you import JAX utils into PyTorch/XLA, you:
- List the source on each util you imported.
- Import their docstrings as well. Those are really critical for the readability of the code.
Also, have you checked the licenses to make sure we can copy code from JAX into PyTorch/XLA? If not, I can do the research for you.
Force-pushed from 79336d3 to 572548b.
Mostly looking good to me. Thanks, @khatwanimohit.
Please address the comments on readability.
  def get_logical_mesh(self):
    return self.device_ids.reshape(self.mesh_shape)


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
Can you make this a per-helper comment for each helper that you imported?
      out[coords[0], coords[1], coords[2]] = d
    return out

  def _create_device_mesh_for_nd_torus(
You can add:
This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
    super().__init__(device_ids, mesh_shape, axis_names)

  def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
    r"""Rearrange TPU devices in a slice into a physical mesh."""
Can you add:
1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
2. The following description of the function:
r"""Rearrange TPU devices in a slice into a physical mesh.
Args:
  devices: A list of device logical ordinals in a TPU slice.
Returns:
  A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
  v2 and v3, global_z is instead cores_per_chip (i.e., 2).
"""
        physical_mesh, mesh_shape)
    return device_mesh

  def _create_hybrid_device_mesh(self, ici_mesh_shape: Sequence[int],
Can you add:
1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
2. The following function description:
"""Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
Args:
  ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
    by increasing network intensity, e.g. [replica, data, mdl] where mdl has
    the most network communication requirements.
  dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
    in the same order as mesh_shape.
Returns:
  A np.ndarray of device logical ordinals with ici_mesh_shape * dcn_mesh_shape as its shape
  that can be fed into HybridMesh for hybrid parallelism.
"""
    return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment

  # This is imported from JAX: https://github.com/main/jax/experimental/mesh_utils.py#L231
  def _create_device_mesh(self,
I didn't mention this one since your logic is quite different. I suggest you undo it.
Fixed the comment
LGTM. Thanks, Mohit.
The TPU CI broke after this PR was merged. Is this related?
Let's have a follow-up to disable the test for TPU. You can do that by following: https://github.com/pytorch/xla/blob/master/test/test_zero1.py#L13
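For that follow-up, a minimal sketch of one way to guard the test (hypothetical check; the actual pattern used in test_zero1.py may differ):

```python
import os
import unittest

class HybridMeshTest(unittest.TestCase):

  @unittest.skipIf(
      os.environ.get('PJRT_DEVICE') == 'TPU',
      'Skip on TPU until the CI failure is resolved')
  def test_hybrid_mesh(self):
    ...
```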