From 53524f12faf8f2ba9c026042edb46939e82b733c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 31 Oct 2024 22:51:32 +0000 Subject: [PATCH] explicitly set device when reusing dist env --- tests/unit/common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/common.py b/tests/unit/common.py index f46ac666f27b..1498b0400ee1 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -277,7 +277,11 @@ def _launch_procs(self, num_procs, init_method): self._launch_daemonic_procs(num_procs, init_method) def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): - if not dist.is_initialized(): + if dist.is_initialized(): + if get_accelerator().is_available(): + # local_rank might not match the rank in the previous run if you are reusing the environment + get_accelerator().set_device(dist.get_rank()) + else: """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: os.environ['MASTER_ADDR'] = '127.0.0.1'