
[PJRT] Experimental support for torch.distributed and DDP on TPU v2/v3 #4520

Merged (10 commits into master, Jan 28, 2023)

Conversation

will-cromar (Collaborator) commented Jan 26, 2023:

  • Use experimental ThreadLocalWorld to enable multithreading
  • Implement a new torch.distributed init_method that uses PJRT runtime parameters and supports multithreading
  • torch.distributed "rank" will become the same as our "ordinal", meaning we have one fewer set of indices to track
  • Converge the PJRT ImageNet and MNIST tests with the original ones
  • Remove experimental pjrt.DistributedDataParallel now that the upstream version works on v3 (see the DDP sketch after the example usage below)

Performance comparison using ResNet50 with fake data on TPU v3:

Runtime   DDP   Throughput (ex/sec/replica)
XRT       No    418.54
XRT       Yes   395.97
PJRT      No    ~640
PJRT      Yes   ~565

Example usage:

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # registers the 'pjrt://' init_method
from torch_xla.experimental import pjrt

def _all_gather(index: int):
  dist.init_process_group('xla', init_method='pjrt://')
  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)
  xm.mark_step()
  print(output)

if __name__ == '__main__':
  xmp.spawn(_all_gather)
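
With the experimental pjrt.DistributedDataParallel wrapper removed, the upstream torch.nn.parallel.DistributedDataParallel is used directly on top of the same 'xla' backend. A minimal DDP sketch in the same style (the toy model, optimizer, and training loop below are illustrative only, not code from this PR):

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # registers the 'pjrt://' init_method

def _ddp_step(index: int):
  dist.init_process_group('xla', init_method='pjrt://')
  device = xm.xla_device()

  # The upstream DDP wrapper handles gradient allreduce across replicas.
  model = nn.Linear(16, 2).to(device)
  ddp_model = nn.parallel.DistributedDataParallel(model)
  optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

  for _ in range(3):
    optimizer.zero_grad()
    loss = ddp_model(torch.randn(8, 16, device=device)).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()

if __name__ == '__main__':
  xmp.spawn(_ddp_step)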

Needs rebasing after #4504 merges

Follow-up:

  • Update automated tests
  • Update PJRT and DDP documentation before release

JackCaoG (Collaborator) commented:
@will-cromar is this one ready for review?

will-cromar marked this pull request as ready for review January 27, 2023 01:48

will-cromar (Collaborator, Author) commented:
I'll take another pass tomorrow to polish and add some comments, but this should be ready for review.

will-cromar removed the DO_NOT_MERGE_YET label ("For PRs which cannot be merged, despite tests passing") Jan 27, 2023
@@ -128,10 +129,11 @@ def ddp_correctness(ddp: type = torch.nn.parallel.DistributedDataParallel,
   offset = rank * local_batch_size
   for step in range(steps):
     # To make torch.randn produce same results across devices.
-    torch.manual_seed(2022 + step)
+    rng = torch.Generator().manual_seed(2022 + step)
Collaborator commented on this diff:

Curious why we use torch.Generator() instead?

will-cromar (Author) replied:

The other option would be to wrap these randn calls in a lock and give them a common global seed, but explicitly creating a new generator with the same seed seems clearer to me. I would have done the same for module initialization, but that case doesn't support a custom RNG.
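
To illustrate the point (a small sketch, not part of this PR): seeding a local torch.Generator makes each replica's torch.randn call deterministic for a given step without locking or mutating the global RNG state.

import torch

def make_batch(step: int) -> torch.Tensor:
  # Hypothetical helper mirroring the pattern in the diff above: a fresh,
  # locally seeded generator makes the draw reproducible for a given step.
  rng = torch.Generator().manual_seed(2022 + step)
  return torch.randn(4, 3, generator=rng)

# Two replicas (threads) calling this for the same step produce identical data.
assert torch.equal(make_batch(0), make_batch(0))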

Collaborator replied:

One last comment: maybe add a comment to the test case explaining the reasoning behind this?

  pjrt._run_multiprocess(
-     util.ddp_correctness, ddp=ddp, use_large_net=True, debug=FLAGS.debug)
+     util.ddp_correctness,
+     init_method='pjrt://',
Collaborator commented on this diff:

Wonder if we want to parameterize the init_method with env:// as well?

will-cromar (Author) replied:

Good idea. Added another test that skips for TPU <= v3, since env:// doesn't work nicely with multithreading.
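
For illustration, such a parameterized test might look roughly like the sketch below (not the actual test added in this PR: util.ddp_correctness comes from the test helpers shown in the diff above, and the tpu.version() check is an assumption about how the TPU generation is detected).

from absl.testing import absltest, parameterized

from torch_xla.experimental import pjrt, tpu
import util  # hypothetical alias for the shared DDP correctness test helpers

class TestDdpInitMethod(parameterized.TestCase):

  @parameterized.named_parameters(('pjrt', 'pjrt://'), ('env', 'env://'))
  def test_ddp_correctness(self, init_method: str):
    # env:// assumes one process per rank, which clashes with the
    # multithreaded replicas used on TPU v2/v3, so skip it there.
    if init_method == 'env://' and tpu.version() <= 3:
      self.skipTest('env:// is not supported with multithreading on TPU v2/v3')
    pjrt._run_multiprocess(util.ddp_correctness, init_method=init_method)

if __name__ == '__main__':
  absltest.main()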

alanwaketan (Collaborator) left a review comment:

LGTM.

JackCaoG (Collaborator) left a review comment:

Thanks!

will-cromar merged commit 021a1cc into master Jan 28, 2023
ManfeiBai pushed a commit that referenced this pull request Jan 30, 2023
…/v3 (#4520)

* Implement multithreaded XLA process group

* Fix tests

* Merge PJRT MNIST test

* formatting

* Clarify random generation in test_ddp.py

* Mark some variables private

* Remove some extra comments

* Add test that uses env:// method

* Explain local RNG

* Explain --pjrt_distributed flag