jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed #15361
Comments
What is your OS?
Hi @nouiz, the OS is Ubuntu 20.04, as indicated above. By the way, I think the problem may be VS Code. There is also the following log.
Thanks for the results.
Recently, a few error messages got a little better.
I also got this error, and it was due to the GPU reaching its memory limit.
@amacrutherford Do you have the full error message?
Same error:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
<ipython-input-25-5a28263ee724> in <cell line: 5>()
3
4 # Initialize model weights using dummy tensors.
----> 5 rng = jax.random.PRNGKey(0)
6 rng, key = jax.random.split(rng)
7 init_img = jnp.ones((4, 224, 224, 5), jnp.float32)
22 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
469 # TODO(sharadmv): remove this fallback when all backends allow `compile`
470 # to take in `host_callbacks`
--> 471 return backend.compile(built_c, compile_options=options)
472
473 _ir_dump_counter = itertools.count()
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
I think the memory is also reaching its limit:
Fri May 12 01:50:33 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 57C P0 27W / 70W | 15101MiB / 15360MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
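The Memory-Usage column in the table above shows 15101MiB of 15360MiB already in use. As a minimal sketch (the helper `memory_fraction_used` is hypothetical, not part of any library), the used/total pair can be parsed from that column to see how close the GPU is to its limit:

```python
import re

# Matches the "used / total" pair in nvidia-smi's Memory-Usage column,
# e.g. "15101MiB / 15360MiB". Power figures like "27W / 70W" are ignored.
_MEM_PATTERN = re.compile(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB")


def memory_fraction_used(smi_line: str) -> float:
    """Hypothetical helper: used/total GPU memory from one nvidia-smi table row."""
    match = _MEM_PATTERN.search(smi_line)
    if match is None:
        raise ValueError(f"no 'used / total' MiB pair in: {smi_line!r}")
    used_mib, total_mib = (int(group) for group in match.groups())
    return used_mib / total_mib


row = "| N/A 57C P0 27W / 70W | 15101MiB / 15360MiB | 0% Default |"
print(f"{memory_fraction_used(row):.1%}")  # prints 98.3%
```

A fraction this close to 1.0 is consistent with the memory-limit explanation given earlier in the thread.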
Yep, I received the same error message as @Bailey-24.
Thanks for the extended error message.
I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only have this problem on my newest system, which has 4x RTX 4090 GPUs. I also have a server with an A100 and a PC with a 3090 Ti, and both work smoothly. After a fresh installation of everything, this is what I get when I run it:
2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
I've also tried swapping my PC's 3090 Ti (which works properly) with one of the 4090s, and I got exactly the same error. This used to work in the past.
Did you cut the output? The error says to look above for more errors.
I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s), to provide solidarity! I can provide more details if useful, but basically I'm running a vanilla Anaconda installation script and trying different variants of
(I'm also unable to run any JAX code, e.g.
@ampolloreno Please open new issues rather than appending to closed ones. However, I think the problem in your case is simple: JAX no longer supports Kepler GPUs in the wheels we release. You can probably rebuild jaxlib from source if you need Kepler support, but note that NVIDIA has dropped Kepler support from CUDA 12 and CUDNN 8.9, so this may not remain true for long.
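To check whether a machine has Kepler GPUs (compute capability 3.x; the K80 is 3.7), recent NVIDIA drivers can report the compute capability via `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader`. A minimal sketch, assuming that CSV output format (older drivers may not support the `compute_cap` field); `parse_compute_caps` and `kepler_gpus` are hypothetical helpers:

```python
from __future__ import annotations

import subprocess


def parse_compute_caps(csv_text: str) -> list[tuple[str, int, int]]:
    """Hypothetical helper: parse 'name, major.minor' lines into tuples."""
    gpus = []
    for line in csv_text.strip().splitlines():
        # rsplit so that a comma inside the GPU name does not break parsing
        name, cap = (field.strip() for field in line.rsplit(",", 1))
        major, minor = (int(part) for part in cap.split("."))
        gpus.append((name, major, minor))
    return gpus


def kepler_gpus(csv_text: str) -> list[str]:
    """Return the names of GPUs whose compute capability is 3.x (Kepler)."""
    return [name for name, major, _ in parse_compute_caps(csv_text) if major == 3]


if __name__ == "__main__":
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        ).stdout
        print("Kepler GPUs:", kepler_gpus(out) or "none")
    except (OSError, subprocess.CalledProcessError, ValueError) as err:
        # No nvidia-smi, an old driver, or unexpected output
        print("could not query compute capability:", err)
```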
@nouiz No, this is all the output. It's several lines of errors; are you seeing all of it? I managed to resolve this, though: I installed CUDA 11 and cuDNN 8.6. In my experiments I also installed the latest version of everything, but this was the only version combination that worked for me. Now I'm getting other errors, but that's a different problem.
@hawkinsp Point taken, and thanks for the help! I just switched over to V100s and voilà!
Thanks @hosseybposh! For a simple use case, I was able to use JAX 0.4.13 and CUDA 11.8 with CUDNN 8.6. I needed to add
I also encountered this error:
@TInaWangxue's problem was resolved in #16901.
Hi, I have a similar issue. Please help me! The output: What jax/jaxlib version are you using? My conda virtual environment: But my OS environment: I've tried everything I can, but ……
Hi, I encountered the same problem. When I use an A100 to run a single task, it runs normally, but when I submit two tasks at the same time, the above error is reported. Is the cause that the A100 is running two tasks at the same time? Is there a conflict?
I suppose 2 tasks means 2 processes. If not, tell us. By default, JAX preallocates 75% of the GPU memory, so the 2nd process will end up short of GPU memory most of the time. Read that web page to learn how to control that 75% memory allocation. If you can lower it to 45% and the first process still has enough memory, it will probably work. Otherwise, try a few other values.
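JAX's GPU memory preallocation is controlled by environment variables that must be set before `jax` is imported: `XLA_PYTHON_CLIENT_MEM_FRACTION` sets the fraction to preallocate, and `XLA_PYTHON_CLIENT_PREALLOCATE=false` disables preallocation entirely. A minimal sketch for the two-process case, assuming each process calls the hypothetical helper `configure_gpu_memory` at the very top of its script:

```python
from __future__ import annotations

import os


def configure_gpu_memory(fraction: float = 0.45, preallocate: bool = True) -> dict[str, str]:
    """Hypothetical helper: set JAX's GPU memory env vars. Call before `import jax`."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    settings = {
        # Fraction of total GPU memory JAX preallocates (the default is 75%).
        "XLA_PYTHON_CLIENT_MEM_FRACTION": f"{fraction:.2f}",
        # "false" makes JAX allocate on demand instead of reserving memory up front.
        "XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false",
    }
    os.environ.update(settings)
    return settings


configure_gpu_memory(0.45)
# import jax  # import JAX only after the variables are set
```

With two processes at 0.45 each, JAX reserves 90% of the GPU in total, which leaves some headroom outside JAX's pool.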
Then it worked. You can look at this link (https://blog.csdn.net/2201_75882736/article/details/132812927).
Hi, I also have the same issue. Could anyone please help me?
The stack trace below excludes JAX-internal frames.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
My JAX version:
My GPU information:
It should be an RTX 2070. Thanks a lot for the help!
You are using an old JAX version (0.4.13) and an old driver that is for CUDA 11.8.
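To report exact versions in a thread like this one, the installed package versions can be read without importing JAX (useful when `import jax` itself fails). A minimal sketch using the standard-library `importlib.metadata`; the package list is an assumption, and CUDA-enabled installs may include additional plugin packages. Once JAX imports cleanly, `jax.print_environment_info()` prints a fuller environment summary.

```python
from __future__ import annotations

from importlib import metadata


def installed_versions(packages: tuple[str, ...] = ("jax", "jaxlib")) -> dict[str, str | None]:
    """Hypothetical helper: map each package to its installed version, or None."""
    versions: dict[str, str | None] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for pkg, ver in installed_versions().items():
    print(f"{pkg}: {ver or 'not installed'}")
```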
Thanks a lot. I tried that, and it worked!
Hi guys. My current jax, cuDNN, and nvidia-smi versions are as follows.
Here's the problem I'm encountering:
I can't find "the errors above." Can someone offer any guesses or solutions?
JAX releases don't support CUDNN 9 yet. Downgrade to CUDNN 8.9 (or build jaxlib from source with CUDNN 9; that also works).
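To confirm which cuDNN major version is actually installed, the version macros in `cudnn_version.h` can be read directly. The issue template asks for `/usr/local/cuda/include/cudnn_version.h`; that path is assumed here and varies by installation. A minimal sketch; `cudnn_version_from_header` is a hypothetical helper:

```python
from __future__ import annotations

import re
from pathlib import Path

# Assumed location, taken from the issue template; adjust for your install.
DEFAULT_HEADER = Path("/usr/local/cuda/include/cudnn_version.h")


def cudnn_version_from_header(text: str) -> tuple[int, int, int]:
    """Hypothetical helper: (major, minor, patch) from cudnn_version.h #defines."""
    parts = []
    for macro in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{macro}\s+(\d+)", text)
        if match is None:
            raise ValueError(f"{macro} not found in header text")
        parts.append(int(match.group(1)))
    return (parts[0], parts[1], parts[2])


if DEFAULT_HEADER.exists():
    version = cudnn_version_from_header(DEFAULT_HEADER.read_text())
    print("cuDNN", ".".join(map(str, version)))
    if version[0] >= 9:
        print("cuDNN 9 found; the advice here is to use 8.9 or build jaxlib from source")
```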
I'm having the same issue. Can someone please help?
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
nvcc --version
nvidia-smi
pip show jax jaxlib
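The three commands listed in this comment are the usual diagnostics to attach. A minimal sketch that runs them and collects their output in one place; the command list is copied from the comment, and `pip show` is invoked via `python -m pip` (an assumption) so it matches the running interpreter. The traceback message also suggests setting `JAX_TRACEBACK_FILTERING=off` in the environment to keep JAX's internal frames.

```python
from __future__ import annotations

import subprocess
import sys

# Diagnostics from the comment; pip is run through the current interpreter.
COMMANDS = {
    "nvcc --version": ["nvcc", "--version"],
    "nvidia-smi": ["nvidia-smi"],
    "pip show jax jaxlib": [sys.executable, "-m", "pip", "show", "jax", "jaxlib"],
}


def collect_diagnostics(commands: dict[str, list[str]] = COMMANDS,
                        timeout: float = 60.0) -> dict[str, str]:
    """Hypothetical helper: run each command, keep stdout+stderr or the failure reason."""
    report = {}
    for label, argv in commands.items():
        try:
            result = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
            report[label] = (result.stdout + result.stderr).strip()
        except (OSError, subprocess.TimeoutExpired) as err:
            # e.g. nvcc is not on PATH
            report[label] = f"<failed: {err}>"
    return report


if __name__ == "__main__":
    for label, output in collect_diagnostics().items():
        print(f"$ {label}\n{output}\n")
```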
Description
I have a Python virtual environment with a clean installation of JAX.
When I run my scripts they usually work perfectly, but sometimes I get the following error; across runs I see between 2 and 10 successful executions and between 1 and 3 failed ones.
What jax/jaxlib version are you using?
jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.8.10, Ubuntu 20.04
NVIDIA GPU info
CUDNN version (/usr/local/cuda/include/cudnn_version.h)
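The local version tag on the jaxlib wheel reported in this issue, `0.4.7+cuda12.cudnn88`, records the CUDA and cuDNN versions it was built against (CUDA 12, cuDNN 8.8). A minimal sketch that decodes such a tag, assuming the `cudaNN.cudnnMN` naming seen in this report; other builds may use different tags, and `decode_jaxlib_tag` is a hypothetical helper:

```python
from __future__ import annotations

import re

# Assumed tag format, e.g. "0.4.7+cuda12.cudnn88": CUDA major, then cuDNN major+minor digits.
_TAG = re.compile(r"\+cuda(\d+)\.cudnn(\d)(\d+)")


def decode_jaxlib_tag(version: str) -> tuple[str, str] | None:
    """Hypothetical helper: (cuda, cudnn) versions encoded in a jaxlib version, or None."""
    match = _TAG.search(version)
    if match is None:
        return None  # e.g. a CPU-only build without a local version tag
    cuda, cudnn_major, cudnn_minor = match.groups()
    return cuda, f"{cudnn_major}.{cudnn_minor}"


print(decode_jaxlib_tag("0.4.7+cuda12.cudnn88"))  # prints ('12', '8.8')
```

Comparing the decoded cuDNN version with the one in `cudnn_version.h` is one quick way to spot a mismatch between the wheel and the system install.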