
Add multi-host GPU support #5657

Merged — vanbasten23 merged 24 commits into master on Oct 19, 2023
Conversation

@vanbasten23 (Collaborator) commented Sep 29, 2023

Collaborating with @wbmc

To start multi-host (or multi-node) training, run the following on each host:

PJRT_DEVICE=GPU torchrun \
  --nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
  --nnodes=${NUMBER_GPU_VM} \
  --node_rank=${CURRENT_HOST_RANK} \
  --rdzv_endpoint=<internal_ip_address> multinode_training_script.py

The documentation will be updated in a follow-up PR.
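For illustration, a minimal multinode_training_script.py could look like the sketch below. This is a hypothetical example, not code from this PR: the model and data are placeholders, and which import provides the xla:// init method depends on the torch_xla version.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the `xla` process group backend
import torch_xla.experimental.pjrt_backend  # older releases need this for `xla://`; newer ones may not


def main():
  # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, etc.; the xla
  # backend and the PJRT GPU client read them to form the global device mesh.
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()

  model = nn.Linear(128, 10).to(device)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

  for step in range(10):
    optimizer.zero_grad()
    data = torch.randn(32, 128, device=device)
    target = torch.randint(0, 10, (32,), device=device)
    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    # all-reduces gradients across every process on every host, then steps
    xm.optimizer_step(optimizer)
    xm.mark_step()  # cut the graph and execute the step


if __name__ == "__main__":
  main()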

@JackCaoG (Collaborator)

@wbmc I have granted you write access.

@wbmc (Collaborator) commented Sep 29, 2023

> @wbmc I have granted you write access.

OK, thanks!

  std::string dist_service_addr =
-      runtime::sys_util::GetEnvString("PJRT_DIST_SERVICE_ADDR", "");
+      runtime::sys_util::GetEnvString("MASTER_ADDR", "127.0.0.1") + ":" + port;
@miladm (Collaborator) commented Oct 10, 2023

For better readability, can we introduce a variable making clear that this IP address is the default, e.g. LOCAL_HOST_IP_DEFAULT?

@vanbasten23 (Collaborator Author)

Sure.

@vanbasten23 (Collaborator Author)

Done.

@vanbasten23 marked this pull request as ready for review on October 11, 2023 00:01
@jonb377 (Collaborator) left a comment

Looking great, Xiongfei!

  auto distributed_client =
-      MaybeInitializeDistributedRuntimeClient(local_rank, dist_service_addr);
+      MaybeInitializeDistributedRuntimeClient(global_rank);
+  auto allowed_devices =
+      std::make_optional<std::set<int>>(std::set{local_rank});
@jonb377 (Collaborator)

We need to generalize this to support CUDA_VISIBLE_DEVICES and single-process-multi-device.

@vanbasten23 (Collaborator Author)

Yeah, we initially planned to incorporate CUDA_VISIBLE_DEVICES into this PR, but we ran into some errors such as #5558 (comment) and #5558 (comment). We still plan to do it, probably in a follow-up PR.
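As a rough illustration of the generalization being discussed (a hypothetical Python sketch, not this PR's C++ code), the idea is that each process picks its device ordinals from whatever CUDA has made visible, rather than always using {local_rank}:

import os


def allowed_device_ordinals(local_rank: int, devices_per_process: int = 1) -> set:
  """Hypothetical: compute the device ordinals this process may use,
  honoring CUDA_VISIBLE_DEVICES and single-process-multi-device."""
  visible = os.environ.get("CUDA_VISIBLE_DEVICES")
  if visible:
    # CUDA renumbers visible devices starting from 0, so ordinals index into
    # the visible list rather than the physical GPU ids.
    num_visible = len([d for d in visible.split(",") if d.strip()])
  else:
    num_visible = devices_per_process * int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
  start = local_rank * devices_per_process
  assert start + devices_per_process <= num_visible, "not enough visible GPUs"
  return set(range(start, start + devices_per_process))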

@@ -14,7 +15,8 @@ def num_local_processes() -> int:
   """
   assert xenv.GPU_NUM_DEVICES in os.environ, \
     "Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client"
-  return int(os.environ[xenv.GPU_NUM_DEVICES])
+  os.environ[xenv.WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES]
@jonb377 (Collaborator)

Is this right? We'll clobber the torchrun-set world size. Also wondering if we need to keep GPU_NUM_DEVICES in the first place.

@vanbasten23 (Collaborator Author) commented Oct 12, 2023

This change is mainly to make the single-host case work with spawn (e.g. the tests in pytorch/xla/test/pjrt/test_runtime_gpu.py, which use spawn). To provide a UX similar to PyTorch, we should still support spawn for the single-host case (fwiw, PyTorch supports it: https://screenshot.googleplex.com/7nKD68dXNUUskF7). But I like your idea of replacing GPU_NUM_DEVICES with WORLD_SIZE.

Also, torchrun doesn't invoke spawn, so this function wouldn't be called and hence wouldn't overwrite the torchrun-set world size.

So how about I replace GPU_NUM_DEVICES with WORLD_SIZE?

@jonb377 (Collaborator)

WORLD_SIZE isn't quite right here, since this function returns the expected number of local processes. In torchrun, that's set as LOCAL_WORLD_SIZE. In xmp.spawn, we had to set it as something different, like PJRT_LOCAL_WORLD_SIZE, because LOCAL_WORLD_SIZE caused an issue with a third-party package.

So maybe os.environ.get('LOCAL_WORLD_SIZE') or os.environ.get('PJRT_LOCAL_WORLD_SIZE')? It's clunky, but it covers both cases.

To use a subset of local GPUs with xmp.spawn, a user could then set LOCAL_WORLD_SIZE themselves instead of GPU_NUM_DEVICES.

@vanbasten23 (Collaborator Author)

Agreed that WORLD_SIZE is confusing. OTOH, WORLD_SIZE is what PyTorch single-node training uses (https://screenshot.googleplex.com/BH27HAYTpbNU4KA), if we want to stay closer to PyTorch.

I think clarity is more important here. What do you think about replacing GPU_NUM_DEVICES with PJRT_LOCAL_WORLD_SIZE (or something like PT_XLA_LOCAL_WORLD_SIZE if we don't want to leak the underlying implementation detail), so that the single-host multi-GPU invocation GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py ... becomes PJRT_LOCAL_WORLD_SIZE=4 python3 xla/test/test_train_mp_imagenet.py ...? @will-cromar @jonb377

@will-cromar (Collaborator)

I think it's okay to rely on the LOCAL_WORLD_SIZE env var here. We know torchrun will set it for sure, and the manual single-host multi-GPU case can become PJRT_DEVICE=GPU LOCAL_WORLD_SIZE=4 python script.py*. We can get rid of GPU_NUM_DEVICES.

* As a follow-up, I would like to implement more automatic configuration like we have with TPUs, so users don't have to set anything in the default case.

We definitely don't want to override torchrun's settings here.

I forgot that PJRT_LOCAL_WORLD_SIZE is set after this function is called (and probably just set to the output of this function). So we can ignore that variable here.

@vanbasten23 (Collaborator Author)

Yeah, I used LOCAL_WORLD_SIZE here.

I'll replace GPU_NUM_DEVICES with LOCAL_WORLD_SIZE in a follow-up PR.
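For reference, a sketch of what that follow-up could look like (an illustrative sketch using plain env var names, not the merged code):

import os


def num_local_processes() -> int:
  """Sketch: prefer the torchrun-style LOCAL_WORLD_SIZE and fall back to the
  legacy GPU_NUM_DEVICES, so a value set by torchrun is never clobbered."""
  if "LOCAL_WORLD_SIZE" in os.environ:
    return int(os.environ["LOCAL_WORLD_SIZE"])
  assert "GPU_NUM_DEVICES" in os.environ, \
      "Set `LOCAL_WORLD_SIZE` (or legacy `GPU_NUM_DEVICES`) to use the PjRt GPU client"
  os.environ["LOCAL_WORLD_SIZE"] = os.environ["GPU_NUM_DEVICES"]
  return int(os.environ["LOCAL_WORLD_SIZE"])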

@yeounoh self-requested a review on October 12, 2023 17:55
@will-cromar (Collaborator) left a comment

Great work! I'm really excited to see your progress here.

While this PR is in review, can you start updating the documentation for GPUs in this repository? We should have some documentation covering how/why the GPU runtime works and making it clear that we expect people to use torchrun for multi-host use cases.


@vanbasten23 (Collaborator Author)

> Great work! I'm really excited to see your progress here.
>
> While this PR is in review, can you start updating the documentation for GPUs in this repository?

I added the documentation in #5704. Feel free to take a look as well. I'll do some testing for this feature and may make some fixes if necessary. Once the feature is more stable, I'll merge the GPU documentation PR.

@will-cromar (Collaborator) left a comment

Just a couple of minor nits still open. Otherwise LGTM!

@jonb377 (Collaborator) left a comment

LGTM, thanks Xiongfei!

@vanbasten23 merged commit 6ea9947 into master on Oct 19, 2023
ghpvnist pushed a commit to ghpvnist/xla that referenced this pull request Oct 31, 2023
* add prints

* to be continued.

* made torchrun works on single host

* Add an example of resnet torchrun

* add prints

* use local rank for allowed_devices.

* remove unwanted comments

* remove comments

* Add torchrun test to the CI.

* added a ll_reduce test

* fix ci failures

* remove some comments

* provide an alternative way to set the port for coordinator.

* fix test by destroying the process group after the test

* fix the single host test.

* fix single host gpu tests.

* add reduce scatter test

* fix comments

* fix a comment

* fix comments

* fix linter

* fix comments

* Use Local_WORLD_SIZE for spawn case.

* fix more comments
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024