PyTorch DTensor device mesh interface with device type "xla" fails at get_rank() #8528

Closed
awshaichen opened this issue Jan 3, 2025 · 4 comments

awshaichen commented Jan 3, 2025

🐛 Bug

After constructing a PyTorch DTensor device mesh with torch.distributed._tensor.device_mesh.init_device_mesh and device_type="xla", the resulting device mesh does not support querying the current rank via the get_rank() interface.

To Reproduce

# test_device_mesh_get_rank.py
#
# test_driver re-launches this file under torchrun; each torchrun worker then
# runs realtest, which builds a device mesh and asserts get_rank() matches RANK.
import os
import subprocess
import unittest
from torch.distributed._tensor.device_mesh import init_device_mesh


class TestDeviceMeshGetRank(unittest.TestCase):

    def realtest(self):
        # Runs inside each torchrun worker; WORLD_SIZE and RANK are set by torchrun.
        _world_size = int(os.environ["WORLD_SIZE"])
        device_type = os.environ.get("TEST_DEVICE_TYPE", 'xla')
        if device_type == 'xla':
            from torch_xla import runtime as xr
            xr.use_spmd()
        device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
        _rank = device_mesh.get_rank()  # fails here when device_type == 'xla'
        assert _rank == int(os.environ["RANK"])

    def test_driver(self):
        if 'TEST_INTERNAL_IS_TORCHRUN' in os.environ:
            # We are inside a torchrun worker; run the actual test body.
            return self.realtest()
        device_count = 2
        env = os.environ.copy()
        env['TEST_INTERNAL_IS_TORCHRUN'] = '1'
        cmd = ['torchrun', '--nnodes=1', f'--nproc_per_node={device_count}', __file__]
        subprocess.check_call(cmd, env=env)


if __name__ == '__main__':
    unittest.main()

Steps to reproduce the behavior:

  1. Save the above script as test_device_mesh_get_rank.py.
  2. Running env PJRT_DEVICE=CPU python test_device_mesh_get_rank.py under torch-xla 2.5.1 fails with ValueError: Default process group has not been initialized, please make sure to call init_process_group.
  3. For comparison, running env TEST_DEVICE_TYPE='cuda' python test_device_mesh_get_rank.py with CUDA PyTorch passes the test.

Expected behavior

device_mesh.get_rank() should return the current rank (matching the RANK environment variable set by torchrun), as it does for the CUDA device type.

Environment

  • Reproducible on XLA backend: CPU and the AWS Neuron PJRT plugin
  • torch_xla version: 2.5.1

Additional context

Was trying to adapt https://github.com/pytorch/examples/blob/1bef748/distributed/tensor_parallelism/tensor_parallel_example.py for the XLA device/mesh type.
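
As far as I can tell, the ValueError comes from torch.distributed itself: DeviceMesh.get_rank() delegates to torch.distributed.get_rank(), which requires an initialized default process group. A minimal illustration of what I believe is the same failure path (my reading of the code, not torch-xla-specific):

# Assumption: this mirrors the failure path of DeviceMesh.get_rank(),
# which delegates to torch.distributed.get_rank().
import torch.distributed as dist

# With no default process group initialized, this raises:
# ValueError: Default process group has not been initialized,
# please make sure to call init_process_group.
dist.get_rank()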

miladm commented Jan 10, 2025

@bhavya01 to assist with this issue.

cc @JackCaoG

bhavya01 commented

The XLA backend for distributed tensors works slightly differently from native PyTorch. The XLA backend doesn't require creating a separate process for each device, because the XLA compiler handles sharding the tensors according to the specified sharding spec.

That's why you don't see any process groups with the XLA backend here.

Please feel free to take a look at the RFC for DTensor integration with the XLA backend, pytorch/pytorch#92909, and let us know if you have any further questions.

The distribute_tensor and distribute_module APIs should work as expected.
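
A minimal sketch of that path (untested here; it assumes torch-xla 2.x with SPMD enabled and a single Python process rather than torchrun). Per my reading of the RFC, distribute_tensor on an "xla" mesh returns an XLAShardedTensor rather than a regular DTensor:

# Minimal sketch (untested): shard a tensor over an "xla" device mesh with
# distribute_tensor instead of querying per-process ranks.
# Assumes torch-xla 2.x with SPMD; run as a single process, not under torchrun.
import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.device_mesh import init_device_mesh

xr.use_spmd()  # single-process SPMD; the compiler handles the sharding

num_devices = xr.global_runtime_device_count()
mesh = init_device_mesh("xla", mesh_shape=(num_devices,))

t = torch.randn(8, 4).to(xm.xla_device())
# Shard dim 0 of the tensor across the 1-D mesh; no process group is needed.
dt = distribute_tensor(t, mesh, [Shard(0)])
print(dt)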

bhavya01 commented

Closing this since there have been no new comments since last week. Please feel free to re-open.

jeffhataws commented

@bhavya01 when will pytorch/pytorch#92909 be completed and merged into mainline?
