From b24dfa9d0822d36b35d9fa214ce432cef5daf1b6 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 1 Nov 2024 05:57:47 -0700 Subject: [PATCH] Explicitly set device when reusing dist env (#6696) A rank of a process can change when reusing the environment. This PR explicitly sets the device when reusing the environment. --- tests/unit/common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/common.py b/tests/unit/common.py index f46ac666f27b..1498b0400ee1 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -277,7 +277,11 @@ def _launch_procs(self, num_procs, init_method): self._launch_daemonic_procs(num_procs, init_method) def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): - if not dist.is_initialized(): + if dist.is_initialized(): + if get_accelerator().is_available(): + # local_rank might not match the rank in the previous run if you are reusing the environment + get_accelerator().set_device(dist.get_rank()) + else: """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: os.environ['MASTER_ADDR'] = '127.0.0.1'