From 85bb074d7110baca0c2e2cb387916ed261a61454 Mon Sep 17 00:00:00 2001 From: zhanglei335 Date: Wed, 9 Oct 2024 19:42:13 +0800 Subject: [PATCH] fix init_device_mesh for torch 2.4 --- deepspeed/comm/torch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index ed2645d415c4..807eb919d70f 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -390,7 +390,12 @@ def init_device_mesh(self, mesh_shape, mesh_dim_names): if not required_torch_version(min_version=2.2): raise RuntimeError(f"Current torch version does not have device mesh" f"api (torch.__version__: {torch.__version__})") - return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(), + if not required_torch_version(max_version=2.4): + return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(), + mesh_shape, + mesh_dim_names=mesh_dim_names) + else: + return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(), mesh_shape, mesh_dim_names=mesh_dim_names)