Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce torch_xla.launch() #7648

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions examples/data_parallel/train_resnet_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla
import torch_xla.core.xla_model as xm


class TrainResNetDDP(TrainResNetBase):

def __init__(self):
super().__init__()
dist.init_process_group('xla', init_method='xla://')
super().__init__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should the super init stay at the top?

Copy link
Collaborator Author

@zpcore zpcore Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, super().__init__() will call some functions from runtime, which needs init the TPU backend first.

self.model = DDP(
self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
Expand All @@ -26,5 +28,7 @@ def _mp_fn(index):


if __name__ == '__main__':
print('consider using train_resnet_spmd_data_parallel.py instead to get better performance')
xmp.spawn(_mp_fn, args=())
print(
'consider using train_resnet_spmd_data_parallel.py instead to get better performance'
)
torch_xla.launch(_mp_fn)
9 changes: 3 additions & 6 deletions test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.profiler as xp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import torch.distributed as dist
Expand Down Expand Up @@ -375,7 +373,6 @@ def _mp_fn(index, flags):


if __name__ == '__main__':
if dist.is_torchelastic_launched():
_mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS)
else:
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
# if running with torchrun, nprocs argument will be omitted.
debug_single_process = True if FLAGS.num_cores == 1 else False
torch_xla.launch(_mp_fn, args=(FLAGS,), debug_single_process=True)
31 changes: 30 additions & 1 deletion torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import contextlib
from typing import List
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_env_vars as xenv
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu


def device(index: int = None) -> torch.device:
Expand Down Expand Up @@ -80,3 +85,27 @@ def manual_seed(seed, device=None):
If missing the default device seed will be set.
"""
xm.set_rng_state(seed, device)


def launch(
fn: Callable,
args: Tuple = (),
start_method: str = 'spawn',
debug_single_process: bool = False,
):
""" Entry to launch multiprocess.

Raises:
NotImplementedError: SPMD is not supported yet.
"""
if xr.is_spmd():
# TODO(piz): SPMD is specified differently from mp. Skip for now.
raise NotImplementedError(
'launch function does not support SPMD at this time')

nprocs = 1 if debug_single_process else None

if dist.is_torchelastic_launched():
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
else:
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
Loading