Add multi-host GPU support #5657
Conversation
@wbmc I have granted you write access.
OK, thanks!
  std::string dist_service_addr =
-     runtime::sys_util::GetEnvString("PJRT_DIST_SERVICE_ADDR", "");
+     runtime::sys_util::GetEnvString("MASTER_ADDR", "127.0.0.1") + ":" + port;
For better readability, can we introduce a variable indicating that this IP address is the default value, e.g. LOCAL_HOST_IP_DEFAULT?
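For concreteness, the suggested refactor might look roughly like this (the constant name and placement are illustrative, not the final code):

```cpp
// Illustrative only: give the default coordinator IP a descriptive name.
const std::string kLocalHostIpDefault = "127.0.0.1";
std::string dist_service_addr =
    runtime::sys_util::GetEnvString("MASTER_ADDR", kLocalHostIpDefault) + ":" + port;
```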
sure
done
Looking great Xiongfei!
  auto distributed_client =
-     MaybeInitializeDistributedRuntimeClient(local_rank, dist_service_addr);
+     MaybeInitializeDistributedRuntimeClient(global_rank);
  auto allowed_devices =
      std::make_optional<std::set<int>>(std::set{local_rank});
We need to generalize this to support CUDA_VISIBLE_DEVICES and single-process-multi-device.
Yeah, we initially planned to incorporate CUDA_VISIBLE_DEVICES into this PR, but we encountered some errors such as #5558 (comment) and #5558 (comment). We still plan to do it, but probably in a follow-up PR.
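For reference, a hedged sketch of what that generalization could look like (not code from this PR, and it glosses over the index remapping that CUDA_VISIBLE_DEVICES implies):

```cpp
// Hypothetical sketch: build the allowed device set from CUDA_VISIBLE_DEVICES
// when it is set, falling back to {local_rank} otherwise. Illustrative only.
#include <optional>
#include <set>
#include <sstream>
#include <string>

std::optional<std::set<int>> BuildAllowedDevices(int local_rank,
                                                 const std::string& visible_devices) {
  std::set<int> allowed;
  if (visible_devices.empty()) {
    // No restriction given: keep the current behavior of one device per process.
    allowed.insert(local_rank);
  } else {
    // Parse the comma-separated device ordinals, e.g. "0,2,3".
    std::stringstream ss(visible_devices);
    std::string id;
    while (std::getline(ss, id, ',')) {
      allowed.insert(std::stoi(id));
    }
  }
  return std::make_optional(allowed);
}
```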
torch_xla/_internal/gpu.py
Outdated
@@ -14,7 +15,8 @@ def num_local_processes() -> int:
   """
   assert xenv.GPU_NUM_DEVICES in os.environ, \
     "Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client"
-  return int(os.environ[xenv.GPU_NUM_DEVICES])
+  os.environ[xenv.WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES]
Is this right? We'll clobber the torchrun-set world size. Also wondering if we need to keep GPU_NUM_DEVICES in the first place.
This change is mainly to make the single-host case work with spawn (e.g. the tests in pytorch/xla/test/pjrt/test_runtime_gpu.py, in which we use spawn). To provide a similar UX to PyTorch, we should still support spawn for the single-host case (fwiw, PyTorch supports it: https://screenshot.googleplex.com/7nKD68dXNUUskF7). But I like your idea of replacing GPU_NUM_DEVICES with WORLD_SIZE.
Also, torchrun doesn't invoke spawn, so this function wouldn't be called, and hence it doesn't overwrite the torchrun-set world size.
So how about I replace GPU_NUM_DEVICES with WORLD_SIZE?
WORLD_SIZE isn't quite right here, since this function returns the expected number of local processes. In torchrun, that's set as LOCAL_WORLD_SIZE. In xmp.spawn, we had to set it as something different, like PJRT_LOCAL_WORLD_SIZE, because LOCAL_WORLD_SIZE caused an issue with a third-party package.
So maybe os.environ.get('LOCAL_WORLD_SIZE') or os.environ.get('PJRT_LOCAL_WORLD_SIZE')? It's clunky, but it covers both cases.
To use a subset of local GPUs with xmp.spawn, a user could then set LOCAL_WORLD_SIZE themselves instead of GPU_NUM_DEVICES.
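A minimal sketch of that fallback (the final default of one process is an assumption, not something decided in this thread):

```python
import os

# Prefer the torchrun-provided value, then the xmp.spawn alias; default to 1 process.
local_process_count = int(
    os.environ.get('LOCAL_WORLD_SIZE')
    or os.environ.get('PJRT_LOCAL_WORLD_SIZE')
    or 1)
```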
Agreed that WORLD_SIZE is confusing. OTOH, WORLD_SIZE is what PyTorch single-node training uses (https://screenshot.googleplex.com/BH27HAYTpbNU4KA), if we want to stay closer to PyTorch.
I think clarity is more important here. What do you think about replacing GPU_NUM_DEVICES with PJRT_LOCAL_WORLD_SIZE (or something like PT_XLA_LOCAL_WORLD_SIZE if we don't want to leak the underlying implementation detail), so the single-host-multi-GPU invocation GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py ... turns into PJRT_LOCAL_WORLD_SIZE=4 python3 xla/test/test_train_mp_imagenet.py ...? @will-cromar @jonb377
I think it's okay to rely on the LOCAL_WORLD_SIZE env var here. We know torchrun will set it for sure, and the manual single-host-multi-GPU case can become PJRT_DEVICE=GPU LOCAL_WORLD_SIZE=4 python script.py*. We can get rid of GPU_NUM_DEVICES.
* As a follow-up, I would like to implement more automatic configuration like we have with TPUs, so users don't have to set anything in the default case.
We definitely don't want to override torchrun's settings here.
I forgot that PJRT_LOCAL_WORLD_SIZE is set after this function is called (and is probably just set to the output of this function), so we can ignore that variable here.
Yeah, I used LOCAL_WORLD_SIZE here.
I'll replace GPU_NUM_DEVICES with LOCAL_WORLD_SIZE in a follow-up PR.
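A hedged sketch of how num_local_processes might look after that follow-up; the fallback to GPU_NUM_DEVICES is an assumption for backward compatibility, not the merged code:

```python
import os


def num_local_processes() -> int:
  """Returns the number of local processes to spawn for the PjRt GPU client."""
  # torchrun exports LOCAL_WORLD_SIZE automatically; with xmp.spawn the user sets it.
  local_world_size = os.environ.get('LOCAL_WORLD_SIZE',
                                    os.environ.get('GPU_NUM_DEVICES', '1'))
  return int(local_world_size)
```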
Great work! I'm really excited to see your progress here.
While this PR is in review, can you start updating the documentation for GPUs in this repository? We should have some documentation covering how and why the GPU runtime works, and make it clear that we expect people to use torchrun for multi-host use cases.
Force-pushed from 0cb25b4 to 5bd9939.
I added the documentation in #5704. Feel free to take a look as well. I'll do some testing for this feature and may make some fixes if necessary. Once the feature is more stable, I'll merge the GPU documentation PR.
Just a couple of minor nits still open. Otherwise LGTM!
Force-pushed from 6e836ba to df4e450.
LGTM, thanks Xiongfei!
* add prints
* to be continued.
* made torchrun work on single host
* Add an example of resnet torchrun
* add prints
* use local rank for allowed_devices.
* remove unwanted comments
* remove comments
* Add torchrun test to the CI.
* added an all_reduce test
* fix ci failures
* remove some comments
* provide an alternative way to set the port for coordinator.
* fix test by destroying the process group after the test
* fix the single host test.
* fix single host gpu tests.
* add reduce scatter test
* fix comments
* fix a comment
* fix comments
* fix linter
* fix comments
* Use LOCAL_WORLD_SIZE for spawn case.
* fix more comments
Collaborating with @wbmc
To start the multi-host (or multi-node) training, do:
on each host.
The documentation will be updated in a follow-up PR.
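For illustration only, a torchrun-based launch on each host might look like the sketch below; the node count, process count, rendezvous endpoint, and script name are placeholders rather than values taken from this PR.

```sh
# Run on every host, with --node_rank set per host (angle-bracket values are placeholders).
PJRT_DEVICE=GPU torchrun \
  --nnodes=<NUM_HOSTS> \
  --node_rank=<HOST_INDEX> \
  --nproc_per_node=<GPUS_PER_HOST> \
  --rdzv_endpoint=<MASTER_ADDR>:<PORT> \
  script.py
```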