[PJRT] Experimental support for torch.distributed and DDP on TPU v2/v3 #4520

Conversation
@will-cromar is this one ready for review?
8699230 to 81b14d1
I'll take another pass tomorrow to polish and add some comments, but this should be ready for review.
@@ -128,10 +129,11 @@ def ddp_correctness(ddp: type = torch.nn.parallel.DistributedDataParallel,
   offset = rank * local_batch_size
   for step in range(steps):
     # To make torch.randn produce same results across devices.
-    torch.manual_seed(2022 + step)
+    rng = torch.Generator().manual_seed(2022 + step)
Curious why we use torch.Generator() instead?
The other option would be to wrap these randn calls in a lock and give them a common global seed, but explicitly creating a new generator with the same seed seems clearer to me. I would have done the same for module initialization, but that case doesn't support a custom RNG.
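For illustration, a minimal standalone sketch of this pattern (the tensor shape and variable names are made up, not taken from the test itself):

import torch

step = 0
# Each replica seeds its own local generator, so no lock on the global RNG is needed.
rng = torch.Generator().manual_seed(2022 + step)
# torch.randn accepts an explicit generator, so every replica draws the same
# values without mutating global random state.
batch = torch.randn(16, 128, generator=rng)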
One last comment: maybe add a comment to the test case explaining the reasoning behind this?
  pjrt._run_multiprocess(
-      util.ddp_correctness, ddp=ddp, use_large_net=True, debug=FLAGS.debug)
+      util.ddp_correctness,
+      init_method='pjrt://',
Wonder if we want to parameterize the init_method with env:// as well?
Good idea. Added another test that skips for TPU <= v3, since env:// doesn't work nicely with multithreading.
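Roughly what such a parameterized test could look like (a sketch only, not the PR's actual code; the tpu.version() helper, the skip condition, and the util import are assumptions based on the diff above):

from absl.testing import absltest, parameterized

from torch_xla.experimental import pjrt, tpu

import util  # hypothetical test utility providing ddp_correctness, as in the diff above


class TestDdpInitMethod(parameterized.TestCase):

  @parameterized.named_parameters(('pjrt', 'pjrt://'), ('env', 'env://'))
  def test_ddp_correctness(self, init_method):
    # env:// expects one process per rank, which clashes with the
    # multithreaded (one process, many threads) model on TPU v2/v3.
    if init_method == 'env://' and tpu.version() <= 3:
      self.skipTest('env:// is not compatible with multithreaded TPU v2/v3')
    pjrt._run_multiprocess(util.ddp_correctness, init_method=init_method)


if __name__ == '__main__':
  absltest.main()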
LGTM.
Thanks!
[PJRT] Experimental support for torch.distributed and DDP on TPU v2/v3 (#4520)

* Implement multithreaded XLA process group
* Fix tests
* Merge PJRT MNIST test
* formatting
* Clarify random generation in test_ddp.py
* Mark some variables private
* Remove some extra comments
* Add test that uses env:// method
* Explain local RNG
* Explain --pjrt_distributed flag
* Add ThreadLocalWorld to enable multithreading with torch.distributed
* Add an init_method that uses PJRT runtime parameters and supports multithreading with torch.distributed
  * "rank" will become the same as our "ordinal", meaning we have one fewer set of indices to track
* Remove pjrt.DistributedDataParallel now that the upstream version works on v3

Performance comparison using ResNet50 with fake data on TPU v3:
Example usage:
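The original snippet did not survive extraction, so here is a rough sketch of what usage might look like, based on the APIs referenced in this PR. The import path torch_xla.experimental.pjrt_backend, the public launcher pjrt.run_multiprocess (the PR's own test uses the private pjrt._run_multiprocess), and the training loop details are assumptions, not the author's original example:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend      # registers the 'xla' process group backend
import torch_xla.experimental.pjrt_backend    # registers the 'pjrt://' init_method (assumed path)
from torch_xla.experimental import pjrt


def _mp_fn():
  # With the pjrt:// init_method, the torch.distributed rank is derived from
  # PJRT runtime parameters and matches the XLA ordinal.
  dist.init_process_group('xla', init_method='pjrt://')

  device = xm.xla_device()
  model = nn.Linear(128, 10).to(device)
  # Upstream DDP, rather than the removed pjrt.DistributedDataParallel wrapper.
  ddp_model = DDP(model, gradient_as_bucket_view=True)

  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)
  loss_fn = nn.MSELoss()

  for _ in range(10):
    optimizer.zero_grad()
    data = torch.randn(16, 128, device=device)
    target = torch.randn(16, 10, device=device)
    loss_fn(ddp_model(data), target).backward()
    optimizer.step()
    xm.mark_step()


if __name__ == '__main__':
  # Run with PJRT_DEVICE=TPU. On TPU v2/v3, PJRT runs multiple threads within
  # each host process, which is why the multithreaded XLA process group is needed.
  pjrt.run_multiprocess(_mp_fn)  # assumed public launcher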
Needs rebasing after #4504 merges
Follow-up: