Support 3rd-party distributed backend (#706)
Support 3rd-party distributed backends once they have been registered via
[torch.distributed.Backend.register_backend](https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend.register_backend).
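
For context, a backend plugin is expected to register itself before torchtitan initializes the process group. A minimal, hypothetical sketch of such a registration follows; the "my_backend" / "npu" names and the stub constructor are illustrative assumptions, not part of this change:

```python
# Hypothetical sketch: registering a third-party backend so that it appears in
# torch.distributed.Backend.default_device_backend_map, which the new
# _get_distributed_backend helper consults. Names here are illustrative only.
import torch.distributed as dist

def _create_custom_pg(store, rank, world_size, timeout):
    # A real plugin would build and return its ProcessGroup implementation here
    # (signature follows the non-extended register_backend API).
    raise NotImplementedError

dist.Backend.register_backend(
    "my_backend",        # backend name, e.g. "hccl" for Ascend NPUs
    _create_custom_pg,   # constructor invoked by init_process_group
    devices=["npu"],     # device type this backend should serve
)

# After registration, default_device_backend_map should map "npu" -> "my_backend",
# so _get_distributed_backend picks it up whenever device_type == "npu".
```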

Co-authored-by: zhangqiongwen <zhangqiongwen@huawei.com>
zqwenn and zhangqiongwen authored Dec 3, 2024
1 parent 4d88a8a commit 93ba30f
Showing 1 changed file with 10 additions and 4 deletions.
torchtitan/utils.py: 14 changes (10 additions & 4 deletions)
@@ -172,6 +172,15 @@ def context(cp_context: Optional[Generator[None, None, None]] = None):
     return context
 
 
+def _get_distributed_backend(job_config):
+    backend = "nccl"
+    if device_type in torch.distributed.Backend.default_device_backend_map.keys():
+        backend = torch.distributed.Backend.default_device_backend_map.get(device_type)
+    if job_config.training.enable_cpu_offload:
+        backend = f"{device_type}:{backend},cpu:gloo"
+    return backend
+
+
 def init_distributed(job_config):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
@@ -193,11 +202,8 @@ def init_distributed(job_config):
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
-    backend = "nccl"
-    if job_config.training.enable_cpu_offload:
-        backend = "cuda:nccl,cpu:gloo"
     torch.distributed.init_process_group(
-        backend=backend,
+        backend=_get_distributed_backend(job_config),
         timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
     )
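
For a sense of what the new helper returns, here is a rough, hypothetical walk-through; the `SimpleNamespace` stand-in for `job_config` and the hard-coded `device_type` are illustrative assumptions, not torchtitan code:

```python
# Hypothetical illustration of _get_distributed_backend on a CUDA machine.
# SimpleNamespace stands in for torchtitan's real job_config; device_type is
# assumed to be the module-level device string that the helper reads.
from types import SimpleNamespace

device_type = "cuda"

job_config = SimpleNamespace(training=SimpleNamespace(enable_cpu_offload=True))
# default_device_backend_map maps "cuda" to "nccl", and CPU offload is enabled,
# so the helper would return "cuda:nccl,cpu:gloo".

job_config.training.enable_cpu_offload = False
# Without CPU offload the helper would return just "nccl".
```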
