diff --git a/tests/unit/common.py b/tests/unit/common.py index f46ac666f27b..1498b0400ee1 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -277,7 +277,11 @@ def _launch_procs(self, num_procs, init_method): self._launch_daemonic_procs(num_procs, init_method) def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): - if not dist.is_initialized(): + if dist.is_initialized(): + if get_accelerator().is_available(): + # local_rank might not match the rank in the previous run if you are reusing the environment + get_accelerator().set_device(dist.get_rank()) + else: """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: os.environ['MASTER_ADDR'] = '127.0.0.1'