Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPMD] Introduce global mesh #6498

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,13 @@ def test_mark_shard_scalar(self):
with self.assertRaises(AttributeError):
xt.mesh_shape

def test_global_mesh(self):
  """The mesh registered via set_global_mesh is returned by get_global_mesh.

  The getter must return the very same Mesh object that was registered,
  not a copy, so downstream code can rely on object identity.
  """
  expected_mesh = self._get_mesh((1, self.n_devices))
  xs.set_global_mesh(expected_mesh)
  mesh = xs.get_global_mesh()

  # assertIs checks object identity directly and produces a readable
  # failure message, unlike comparing id() values with assertEqual.
  self.assertIs(mesh, expected_mesh)


if __name__ == '__main__':
test = unittest.main()
Expand Down
5 changes: 4 additions & 1 deletion torch_xla/distributed/spmd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .xla_sharded_tensor import XLAShard, XLAShardedTensor
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward)
wrap_if_sharded, xla_patched_nn_linear_forward,
set_global_mesh, get_global_mesh)
from .api import xla_distribute_tensor, xla_distribute_module

__all__ = [
Expand All @@ -18,4 +19,6 @@
"xla_distribute_tensor",
"xla_distribute_module",
"xla_patched_nn_linear_forward",
"set_global_mesh",
"get_global_mesh",
]
13 changes: 13 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ def get_op_sharding(self,
replication_groups, sharding_type)


# Process-wide default mesh. It starts out unset (None) and is only ever
# assigned through set_global_mesh(). The annotation is a string so it does
# not require importing Optional at module scope.
_GLOBAL_MESH: "Optional[Mesh]" = None


def set_global_mesh(mesh: "Mesh") -> None:
  """Registers ``mesh`` as the process-wide global mesh.

  Args:
    mesh: the Mesh instance that later calls to ``get_global_mesh`` should
      return. The object is stored as-is (no copy).
  """
  global _GLOBAL_MESH
  _GLOBAL_MESH = mesh


def get_global_mesh() -> "Optional[Mesh]":
  """Returns the mesh registered via ``set_global_mesh``, or None if unset."""
  # A plain read of a module-level name needs no `global` declaration.
  return _GLOBAL_MESH


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
class HybridMesh(Mesh):
"""Creates a hybrid device mesh of devices connected with ICI and DCN networks.
Expand Down
Loading