
Revert "Set all_reduce_token to None when exiting" #6321

Merged Jan 19, 2024 (2 commits)

Conversation

@vanbasten23 (Collaborator) commented Jan 18, 2024

Reverts #6247

See #6320 and #6247 (comment) for details. Also, we observed an error in the GPU CI results for #6247 even though the GPU CI didn't catch it. This needs to be investigated.

@vanbasten23 vanbasten23 requested a review from JackCaoG January 18, 2024 19:13
@vanbasten23 vanbasten23 marked this pull request as ready for review January 18, 2024 19:13
@JackCaoG (Collaborator)

Let's wait and manually verify all GPU tests.

@vanbasten23 (Collaborator, Author) commented Jan 18, 2024

The GPU CI failed, but the CI log didn't show any failing test. However, I can see this error:

2024-01-18T20:42:31.8381044Z *** Received signal 11 ***
2024-01-18T20:42:31.8381784Z *** BEGIN MANGLED STACK TRACE ***
2024-01-18T20:42:31.8401794Z /opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+git0f52f22-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so(+0x5f80ab6)[0x7f4f7a695ab6]
2024-01-18T20:42:31.8403646Z /lib/x86_64-linux-gnu/libpthread.so.0(+0x13140)[0x7f510e284140]
2024-01-18T20:42:31.8404479Z [0x61a1990]
2024-01-18T20:42:31.8404980Z *** END MANGLED STACK TRACE ***
2024-01-18T20:42:31.8405412Z 
2024-01-18T20:42:31.8433341Z *** Begin stack trace ***
2024-01-18T20:42:31.8434232Z 	tsl::CurrentStackTrace[abi:cxx11]()
2024-01-18T20:42:31.8434870Z 	
2024-01-18T20:42:31.8435285Z 	
2024-01-18T20:42:31.8435961Z 	
2024-01-18T20:42:31.8436392Z *** End stack trace ***
2024-01-18T20:42:32.7490678Z ./test/run_tests.sh: line 48: 109476 Aborted                 (core dumped) python3 "$@"

Will try to reproduce locally with the last test that ran, pytorch/xla/test/test_zero1.py.

cc @will-cromar @yeounoh @JackCaoG

@vanbasten23 (Collaborator, Author) commented Jan 18, 2024

Running the last test from the CI run locally, I got:

root@xiowei-gpu:/ansible# GPU_NUM_DEVICES=4  PJRT_DEVICE=CUDA python pytorch/xla/test/test_zero1.py
.
----------------------------------------------------------------------
Ran 1 test in 1.334s

OK
Segmentation fault (core dumped)
root@xiowei-gpu:/ansible#

The test reports OK, but it actually failed with a segmentation fault on exit.

The error

2024-01-18T20:42:31.8381044Z *** Received signal 11 ***
2024-01-18T20:42:31.8381784Z *** BEGIN MANGLED STACK TRACE ***
2024-01-18T20:42:31.8401794Z /opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+git0f52f22-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so(+0x5f80ab6)[0x7f4f7a695ab6]
2024-01-18T20:42:31.8403646Z /lib/x86_64-linux-gnu/libpthread.so.0(+0x13140)[0x7f510e284140]
2024-01-18T20:42:31.8404479Z [0x61a1990]
2024-01-18T20:42:31.8404980Z *** END MANGLED STACK TRACE ***

also appeared for other tests, such as PJRT_DEVICE=CUDA python pytorch/xla/test/pjrt/test_ddp.py, but I was not able to reproduce the test failure there.
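
For context, this "OK, then segfault" pattern is what you get when the crash happens during interpreter shutdown (for example, in an atexit hook) rather than inside a test. A minimal sketch, not taken from this PR, that reproduces the shape of the failure:

import atexit
import ctypes
import unittest


def _crash_at_exit():
  # Dereference a null pointer via ctypes to force a SIGSEGV during the
  # atexit phase, i.e. after the test result has already been reported.
  ctypes.string_at(0)


atexit.register(_crash_at_exit)


class TrivialTest(unittest.TestCase):

  def test_passes(self):
    self.assertTrue(True)


if __name__ == '__main__':
  # Prints "OK", then the process dies with "Segmentation fault (core dumped)".
  unittest.main(exit=False)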

@vanbasten23 (Collaborator, Author)

I also checked the GPU CI log:

Before this PR, the GPU CI actually failed with this error:

2024-01-18T15:18:06.9360003Z [       OK ] TestExperimentalPjrtGpu.test_spawn_xmp
2024-01-18T15:18:06.9361503Z ----------------------------------------------------------------------
2024-01-18T15:18:06.9362290Z Ran 17 tests in 130.401s
2024-01-18T15:18:06.9362675Z 
2024-01-18T15:18:06.9363205Z OK
2024-01-18T15:18:06.9510616Z Warning, backtrace signal handler for signal 11 overwrote previous handler.
2024-01-18T15:18:06.9511893Z Warning, backtrace signal handler for signal 6 overwrote previous handler.
2024-01-18T15:18:06.9513090Z Warning, backtrace signal handler for signal 7 overwrote previous handler.
2024-01-18T15:18:06.9514291Z Warning, backtrace signal handler for signal 4 overwrote previous handler.
2024-01-18T15:18:06.9515473Z Warning, backtrace signal handler for signal 8 overwrote previous handler.
2024-01-18T15:18:06.9553467Z E0118 15:18:06.954417871   37428 server_chttp2.cc:40]        ***"created":"@1705591086.954362503","description":"Only 1 addresses added out of total 2 resolved","file":"external/com_github_grpc_grpc/src/core/ext/transport/chttp2/server/chttp2_server.cc","file_line":404,"referenced_errors":[***"created":"@1705591086.954354914","description":"Address family not supported by protocol","errno":97,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":420,"os_error":"Address family not supported by protocol","syscall":"socket","target_address":"[::1]:8547"***]***
2024-01-18T15:25:06.9876311Z 2024-01-18 15:25:06.986850: E external/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc:493] Failed to disconnect from coordination service with status: DEADLINE_EXCEEDED: Deadline Exceeded
2024-01-18T15:25:06.9879338Z Additional GRPC error information from remote target unknown_target_for_coordination_leader while calling /tensorflow.CoordinationService/ShutdownTask:
2024-01-18T15:25:06.9882338Z :***"created":"@1705591506.986685133","description":"Error received from peer ipv4:127.0.0.1:8547","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Deadline Exceeded","grpc_status":4***
2024-01-18T15:25:06.9885977Z Proceeding with agent shutdown anyway. This is usually caused by an earlier error during execution. Check the logs (this task or the leader) for an earlier error to debug further.
2024-01-18T15:25:06.9889256Z 2024-01-18 15:25:06.987327: E external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:1169] Shutdown barrier in coordination service has failed:
2024-01-18T15:25:06.9897098Z ABORTED: Barrier failed because service is shutting down. Barrier_id: Shutdown::435425134344508282 [type.googleapis.com/tensorflow.CoordinationServiceError='']
2024-01-18T15:25:06.9900503Z This suggests that the workers are out of sync. Either at least one worker is too fast in its execution / crashed early or too slow / hanging. Check the logs for an earlier error to identify the root cause.
2024-01-18T15:25:06.9907406Z 2024-01-18 15:25:06.987385: E external/tsl/tsl/distributed_runtime/coordination/coordination_service.cc:762] INTERNAL: Shutdown barrier has been passed with status: 'ABORTED: Barrier failed because service is shutting down. Barrier_id: Shutdown::435425134344508282 [type.googleapis.com/tensorflow.CoordinationServiceError='']', but this task is not at the barrier yet. [type.googleapis.com/tensorflow.CoordinationServiceError='']
2024-01-18T15:25:06.9917835Z Error in atexit._run_exitfuncs:
2024-01-18T15:25:06.9918574Z Traceback (most recent call last):
2024-01-18T15:25:06.9920225Z   File "/opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+gita8b27eb-py3.8-linux-x86_64.egg/torch_xla/__init__.py", line 151, in _prepare_to_exit
2024-01-18T15:25:06.9922006Z     device = _XLAC._xla_get_default_device()
2024-01-18T15:25:06.9924094Z RuntimeError: torch_xla/csrc/runtime/xla_coordinator.cc:28 : Check failed: dist_runtime_client_->Connect().ok() 
2024-01-18T15:25:06.9925439Z *** Begin stack trace ***
2024-01-18T15:25:06.9926066Z 	tsl::CurrentStackTrace[abi:cxx11]()
2024-01-18T15:25:06.9927863Z 	torch_xla::runtime::XlaCoordinator::XlaCoordinator(int, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
2024-01-18T15:25:06.9930218Z 	torch_xla::runtime::InitializePjRt(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
2024-01-18T15:25:06.9931630Z 	torch_xla::runtime::PjRtComputationClient::PjRtComputationClient()
2024-01-18T15:25:06.9932466Z 	
2024-01-18T15:25:06.9932986Z 	torch_xla::runtime::GetComputationClient()
2024-01-18T15:25:06.9933726Z 	torch_xla::bridge::GetDefaultDevice()
2024-01-18T15:25:06.9934423Z 	torch_xla::bridge::GetCurrentDevice()
2024-01-18T15:25:06.9935159Z 	torch_xla::bridge::GetCurrentAtenDevice()
2024-01-18T15:25:06.9935825Z 	
2024-01-18T15:25:06.9936243Z 	
2024-01-18T15:25:06.9936682Z 	PyCFunction_Call

for example.

With this PR, I don't see error messages such as Check failed: dist_runtime_client_->Connect().ok() in the GPU CI log.

So after this PR is merged, the next step is to fix the GPU CI so that such errors won't go unnoticed.
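
One way the CI could catch this class of failure is to check the exit code of each test process instead of relying only on the unittest summary; a segfaulting process still prints "OK" but exits with a non-zero status. A minimal sketch of such a guard (the helper below is hypothetical, not the actual run_tests.sh change):

import subprocess
import sys


def run_test(test_file):
  """Run a test file in a subprocess and fail on any non-zero exit code."""
  proc = subprocess.run([sys.executable, test_file])
  # A segfault surfaces as returncode -11 (i.e. -SIGSEGV) on POSIX, which is
  # non-zero even though the unittest output says "OK".
  if proc.returncode != 0:
    raise RuntimeError(f'{test_file} exited with code {proc.returncode}')


if __name__ == '__main__':
  run_test('pytorch/xla/test/test_zero1.py')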

@vanbasten23 (Collaborator, Author)

Note: I also need to revert #6268 to make the CI pass. #6268 was backported to r2.2 as well, so we may want to revert it in the r2.2 release branch too. cc @ManfeiBai @JackCaoG

@yitongh (Contributor) commented Jan 19, 2024

@vanbasten23 sorry that #6247 broke the CI. I think the error occurs because the ComputationClient has already been torn down by the time atexit._run_exitfuncs runs when using xla_multiprocessing, whereas the client still exists when using torchrun. So a better solution is to set the all-reduce token to None inside PrepareToExit, like this:

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 8d4997e28..d753f8f7c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -148,8 +148,6 @@ _aws_ec2_inf_trn_init()


 def _prepare_to_exit():
-  device = _XLAC._xla_get_default_device()
-  _XLAC._set_all_reduce_token(device, None)
   _XLAC._prepare_to_exit()
   if int(os.environ.get('PT_XLA_DEBUG', '0')):
     _summarize_fn_tracker()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3281f0e9a..b255bb043 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -97,6 +97,8 @@ void PrepareToExit() {
   runtime::ComputationClient* client =
       runtime::GetComputationClientIfInitialized();
   if (client != nullptr) {
+    auto xla_device = GetDeviceOrCurrent("");
+    SetAllReduceToken(xla_device, nullptr);
     XLAGraphExecutor::Get()->WaitDeviceOps({});
   }
 }

#6247 fixes the issue in #6246, which is a common scenario during testing.
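
To illustrate the ordering issue described above: resetting the token from the Python atexit hook calls into the runtime unconditionally, and the stack trace earlier in this thread shows that _xla_get_default_device() then re-initializes the client (GetComputationClient -> PjRtComputationClient), which fails once the coordinator is gone. Guarding the reset inside PrepareToExit only touches the client when it is still initialized. A minimal Python sketch of the two patterns (illustrative only; client_initialized, reset_token, and get_default_device are stand-ins, not the real torch_xla APIs):

import atexit

# In the xla_multiprocessing case the workers have already torn down the
# client by the time these atexit hooks run.
client_initialized = False


def reset_token():
  print('all-reduce token cleared')


def get_default_device():
  # Stand-in for _XLAC._xla_get_default_device(): querying the device at
  # exit re-creates the runtime client, which is what fails here.
  raise RuntimeError('re-initializing runtime during shutdown')


def unsafe_exit_hook():
  # Pattern from the reverted #6247 (shown for contrast, not registered):
  # always touches the runtime, so it blows up if the client is gone.
  device = get_default_device()
  reset_token()


def guarded_exit_hook():
  # Pattern from the patch above: only act if the client is still alive,
  # mirroring the GetComputationClientIfInitialized() check in PrepareToExit().
  if client_initialized:
    reset_token()


atexit.register(guarded_exit_hook)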

@yitongh (Contributor) commented Jan 19, 2024

BTW, when using xla_multiprocessing, the issue in #6246 will not arise, so mp_test_early_exit.py is not needed; this test should be placed in test/pjrt/test_torchrun.py instead.
