jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed #15361

Closed
Toni-SM opened this issue Apr 2, 2023 · 35 comments
Labels
bug Something isn't working NVIDIA GPU Issues specific to NVIDIA GPUs

Comments

@Toni-SM

Toni-SM commented Apr 2, 2023

Description

I have a Python virtual environment with a clean installation of JAX, installed with:

# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

When I run my scripts they usually work perfectly, but sometimes they fail with the following error: roughly 2 to 10 runs succeed for every 1 to 3 that fail.

2023-04-02 16:00:19.964652: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-02 16:00:19.964737: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
  File "ddpg_jax_gymnasium_pendulum.py", line 73, in <module>
    key = jax.random.PRNGKey(0)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/random.py", line 136, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/prng.py", line 813, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/toni/Documents/SKRL/envs/env_jax/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88
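When comparing reports in this thread, it can help to read the CUDA and cuDNN versions a jaxlib wheel was built against directly from its version string. A minimal sketch, assuming the local tag follows the `+cudaXX.cudnnMN` pattern seen here (so `cudnn88` is read as cuDNN 8.8) and that the installed distribution reports that tag; the function names are hypothetical helpers, not part of JAX:

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Assumed tag layout: "+cuda<major>.cudnn<major><minor>", e.g. "+cuda12.cudnn88".
_TAG = re.compile(r"\+cuda(\d+)\.cudnn(\d)(\d+)$")


def parse_jaxlib_build_tag(version: str) -> Optional[Tuple[int, Tuple[int, int]]]:
    """Return (cuda_major, (cudnn_major, cudnn_minor)), or None if there is no CUDA tag."""
    match = _TAG.search(version)
    if match is None:
        return None  # e.g. a CPU-only jaxlib such as "0.4.7"
    cuda_major = int(match.group(1))
    cudnn = (int(match.group(2)), int(match.group(3)))
    return cuda_major, cudnn


def installed_jaxlib_version() -> Optional[str]:
    """Version string of the installed jaxlib distribution, if any."""
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None
```

For example, `parse_jaxlib_build_tag("0.4.7+cuda12.cudnn88")` returns `(12, (8, 8))`, which can then be compared against the cuDNN version actually installed.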

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.10, Ubuntu 20.04

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   38C    P3               N/A /  55W|     10MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1528      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A      2435      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+

CUDNN version (/usr/local/cuda/include/cudnn_version.h)

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 8
#define CUDNN_PATCHLEVEL 1
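The header above describes the cuDNN installed under `/usr/local/cuda`, which is not necessarily the copy a process loads at runtime. A small sketch (hypothetical helper names; the default path is simply the one from this report) that extracts the version from `cudnn_version.h` so it can be compared with the `cudnnMN` part of the jaxlib tag:

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Location used in the report above; other installs may put the header elsewhere.
CUDNN_HEADER = Path("/usr/local/cuda/include/cudnn_version.h")


def parse_cudnn_header(text: str) -> Optional[Tuple[int, int, int]]:
    """Extract (major, minor, patch) from the CUDNN_* #define lines."""
    values = []
    for name in ("MAJOR", "MINOR", "PATCHLEVEL"):
        match = re.search(rf"#define\s+CUDNN_{name}\s+(\d+)", text)
        if match is None:
            return None  # not a cudnn_version.h, or an unexpected layout
        values.append(int(match.group(1)))
    return values[0], values[1], values[2]


def header_cudnn_version(path: Path = CUDNN_HEADER) -> Optional[Tuple[int, int, int]]:
    """Read and parse the header; None if it cannot be read."""
    try:
        return parse_cudnn_header(path.read_text())
    except OSError:
        return None
```

For the three `#define` lines shown above, `parse_cudnn_header` returns `(8, 8, 1)`.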
@Toni-SM Toni-SM added the bug Something isn't working label Apr 2, 2023
@nouiz
Collaborator

nouiz commented Apr 3, 2023

What is your OS?
Can you confirm that you run the scripts sequentially, so that nothing else is using the GPU in parallel?
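Besides reading the `nvidia-smi` table by eye, this question can be answered by asking the driver which compute processes currently hold GPU memory. A sketch using `nvidia-smi`'s CSV query mode; the parsing helper, the `GpuProcess` type, and the handling of non-numeric memory values are assumptions of this sketch:

```python
import subprocess
from typing import List, NamedTuple

# nvidia-smi CSV query: one line per compute process, no header, no "MiB" suffix.
QUERY = [
    "nvidia-smi",
    "--query-compute-apps=pid,process_name,used_memory",
    "--format=csv,noheader,nounits",
]


class GpuProcess(NamedTuple):
    pid: int
    name: str
    used_mib: int  # -1 when the driver reports no number (e.g. "[N/A]")


def parse_compute_apps(csv_text: str) -> List[GpuProcess]:
    processes = []
    for line in csv_text.strip().splitlines():
        pid, rest = line.split(",", 1)
        # Split from the right so a comma inside the process name cannot break parsing.
        name, used = rest.rsplit(",", 1)
        used = used.strip()
        processes.append(
            GpuProcess(int(pid), name.strip(), int(used) if used.isdigit() else -1)
        )
    return processes


def list_compute_apps() -> List[GpuProcess]:
    """Run nvidia-smi and return the compute processes currently using a GPU."""
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_compute_apps(result.stdout)
```

Running `list_compute_apps()` right before the script would list any other Python process holding memory, such as the `env_gym` interpreter in the `nvidia-smi` output posted below.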

@nouiz nouiz added the NVIDIA GPU Issues specific to NVIDIA GPUs label Apr 3, 2023
@Toni-SM
Author

Toni-SM commented Apr 3, 2023

Hi @nouiz

The OS is Ubuntu 20.04, as indicated above.

By the way, I think the problem may be VS Code.
After running the script several times to reproduce the error, I see that it appears only after I modify and save the script (though not every time I do so).

Here is the corresponding log.
As you can see from running nvidia-smi right after saving the script and just before executing it, GPU memory is already in use. The strange thing is that the memory is held by the Python environment configured in the VS Code bottom-right pane (env_gym), not by the Python of the sourced environment where JAX is installed (env_jax) 🤔

(env_jax) toni@HP-ZBook-Studio-G8:~/Documents/SKRL/skrl/docs/source/examples/gymnasium$ nvidia-smi; python ddpg_jax_gymnasium_pendulum.py 
Mon Apr  3 20:41:40 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   49C    P3               23W /  55W|  12453MiB / 16384MiB |      4%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1536      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A      2456      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A     27050      C   .../SKRL/envs/env_gym/bin/python          12440MiB |
+---------------------------------------------------------------------------------------+
[skrl:INFO] Environment class: gymnasium.core.Wrapper, gymnasium.utils.record_constructor.RecordConstructorArgs
[skrl:INFO] Environment wrapper: Gymnasium
2023-04-03 20:41:43.310989: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-03 20:41:43.311060: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2
Traceback (most recent call last):
...
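Putting the numbers in this log next to JAX's documented default of preallocating 75% of total GPU memory on first use shows why they may not fit: another process already holds 12453 MiB of 16384 MiB, leaving 3931 MiB, while the default preallocation asks for 0.75 × 16384 = 12288 MiB. This is one plausible reading of the failure, not a confirmed diagnosis. A sketch of the arithmetic, with hypothetical helper names:

```python
import re
from typing import Optional, Tuple

# JAX's documented default preallocation fraction (XLA_PYTHON_CLIENT_MEM_FRACTION).
DEFAULT_PREALLOC_FRACTION = 0.75


def parse_memory_cell(table_line: str) -> Optional[Tuple[int, int]]:
    """Extract (used_mib, total_mib) from an nvidia-smi row like '12453MiB / 16384MiB'."""
    match = re.search(r"(\d+)MiB\s*/\s*(\d+)MiB", table_line)
    return (int(match.group(1)), int(match.group(2))) if match else None


def preallocation_fits(used_mib: int, total_mib: int,
                       fraction: float = DEFAULT_PREALLOC_FRACTION) -> bool:
    """True if the memory still free can cover fraction * total."""
    free_mib = total_mib - used_mib
    return fraction * total_mib <= free_mib


row = "| N/A   49C    P3               23W /  55W|  12453MiB / 16384MiB |      4%      Default |"
used, total = parse_memory_cell(row)
print(total - used, DEFAULT_PREALLOC_FRACTION * total)  # free MiB vs. requested MiB
print(preallocation_fits(used, total))
```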

@Toni-SM Toni-SM closed this as completed Apr 4, 2023
@nouiz
Collaborator

nouiz commented Apr 4, 2023

Thanks for the results.
I think we need a way to give a better error to end users.

@nouiz nouiz reopened this Apr 4, 2023
@nouiz
Collaborator

nouiz commented Apr 25, 2023

A few error messages were recently improved.
Closing, as I'm not sure what more to do.
But if the issue appears again and the error isn't informative enough, poke us again.

@nouiz nouiz closed this as completed Apr 25, 2023
@amacrutherford

I also got this error, and it was due to the GPU reaching its memory limit.
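If memory contention is the cause, JAX's documented `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables control how much GPU memory it reserves up front. They are read when JAX initializes its GPU backend, so this sketch conservatively refuses to run once `jax` has been imported. The helper name is hypothetical:

```python
import os
import sys
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None) -> None:
    """Set JAX's GPU memory environment variables before JAX starts up."""
    # Conservative guard: setting these after `import jax` is not guaranteed to work.
    if "jax" in sys.modules:
        raise RuntimeError("configure GPU memory before importing jax")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
```

Call it at the very top of the entry script, before `import jax`. Disabling preallocation trades some allocation speed and fragmentation risk for not grabbing most of the card at startup.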

@nouiz
Collaborator

nouiz commented May 10, 2023

@amacrutherford Do you have the full error message you had?
I would like to improve the error message in that case.

@Bailey-24

Same error:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-25-5a28263ee724> in <cell line: 5>()
      3 
      4 # Initialize model weights using dummy tensors.
----> 5 rng = jax.random.PRNGKey(0)
      6 rng, key = jax.random.split(rng)
      7 init_img = jnp.ones((4, 224, 224, 5), jnp.float32)

22 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
    469   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    470   # to take in `host_callbacks`
--> 471   return backend.compile(built_c, compile_options=options)
    472 
    473 _ir_dump_counter = itertools.count()

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

I think here, too, the memory is reaching its limit:

Fri May 12 01:50:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0    27W /  70W |  15101MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

@amacrutherford

Yep, I received the same error message as @Bailey-24.

@nouiz
Collaborator

nouiz commented May 12, 2023

Thanks for the extended error message.
But can you share the full output without any truncation? It should contain information that would help me.
Which JAX version did you use?

@hosseybposh

hosseybposh commented May 13, 2023

I'm having the same problem, but for me it's consistent and I'm unable to run even simple JAX code. I only have this problem on my newest system with 4x RTX 4090 GPUs; a server with an A100 and a PC with a 3090 Ti both work smoothly.
All systems run Ubuntu 22. I first installed CUDA 11 from conda-forge as suggested: same issue. Then I switched to a local installation of CUDA and cuDNN: same problem.

After a fresh installation of everything, running a = jnp.ones((3,)) gives this error:

2023-05-13 09:04:27.790057: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2122, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1203, in full
return broadcast(fill_value, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
compiled_fun = xla_primitive_callable(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
compiled = _xla_callable_uncached(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
return computation.compile().unsafe_call
^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/hoss/anaconda3/envs/jaxtf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
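The `Memory usage` line at the top of this log can be turned into a free fraction. A sketch with hypothetical helper names: here it gives about 5.85 GB free out of 25.39 GB, roughly 23%. That is consistent with this process having already reserved JAX's default 75% preallocation, so exhaustion by another process does not look like the obvious culprit here; the poster later resolved the error by changing CUDA/cuDNN versions, as reported further down.

```python
import re
from typing import Optional, Tuple

# Matches the XLA log line "Memory usage: <free> bytes free, <total> bytes total."
_MEMORY_LINE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_memory_usage(log_line: str) -> Optional[Tuple[int, int]]:
    match = _MEMORY_LINE.search(log_line)
    return (int(match.group(1)), int(match.group(2))) if match else None


def free_fraction(log_line: str) -> Optional[float]:
    parsed = parse_memory_usage(log_line)
    if parsed is None:
        return None
    free_bytes, total_bytes = parsed
    return free_bytes / total_bytes


line = ("2023-05-13 09:04:27.790140: E external/xla/xla/stream_executor/cuda/"
        "cuda_dnn.cc:443] Memory usage: 5853872128 bytes free, 25393692672 bytes total.")
print(f"{free_fraction(line):.1%} free")
```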

@hosseybposh

hosseybposh commented May 13, 2023

I've also tried swapping my PC's 3090 Ti (which works properly) with one of the 4090s, and I got the exact same error. This setup used to work in the past.
I'm pretty sure the GPU hardware is not the problem; the cards are functional (I tested them on Windows machines).

@nouiz
Copy link
Collaborator

nouiz commented May 15, 2023

Did you truncate the output? The error says to look above for more errors.
If there is more output, give me everything you have. I'll filter out what is useful.

@ampolloreno

I'm getting the same kind of error trying to install jax/jaxlib on an EC2 p2.xlarge (with K80s), to provide solidarity! I can provide more details if useful, but basically, after running a vanilla Anaconda installation script and trying different variants of pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, JAX reports seeing the GPU when I check print(xla_bridge.get_backend().platform), but anything else gives the DNN error above.

@ampolloreno

(I'm also unable to run any JAX code, e.g. a = jnp.ones((3,)).)

@hawkinsp
Collaborator

@ampolloreno Please open new issues rather than appending onto closed ones.

However, I think the problem in your case is simple: JAX no longer supports Kepler GPUs in the wheels we release. You can probably rebuild jaxlib from source if you need Kepler support, but note NVIDIA has dropped Kepler support from CUDA 12 and CUDNN 8.9, so this may not remain true for long.
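Kepler GPUs such as the K80 in a p2.xlarge have CUDA compute capability 3.x, so checking the compute capability shows up front whether a GPU belongs to that family. A sketch mapping compute capability to NVIDIA architecture names (hypothetical helpers); the capability itself can be read from NVIDIA's published tables or, on recent drivers, from `nvidia-smi --query-gpu=compute_cap --format=csv`. Which architectures a given jaxlib wheel supports is deliberately not encoded here:

```python
from typing import Tuple


def architecture(compute_capability: Tuple[int, int]) -> str:
    """Map a CUDA compute capability (major, minor) to an NVIDIA architecture name."""
    major, minor = compute_capability
    if major == 3:
        return "Kepler"
    if major == 5:
        return "Maxwell"
    if major == 6:
        return "Pascal"
    if major == 7:
        return "Turing" if minor >= 5 else "Volta"
    if major == 8:
        return "Ada Lovelace" if minor >= 9 else "Ampere"
    if major == 9:
        return "Hopper"
    return "unknown"  # not covered by this sketch


def is_kepler(compute_capability: Tuple[int, int]) -> bool:
    return compute_capability[0] == 3


def parse_compute_cap(text: str) -> Tuple[int, int]:
    """Parse a value such as '3.7' as printed by nvidia-smi's compute_cap field."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)
```

For example, the K80 (3.7) maps to Kepler, while the V100 (7.0) that later worked maps to Volta.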

@hosseybposh

@nouiz No, this is all the output. It's several lines of errors; are you seeing all of them?

I managed to resolve this, though: I installed CUDA 11 and cuDNN 8.6. In my experiments I also tried the latest version of everything, but this was the only version combination that worked for me. Now I'm getting other errors, but that's a different problem.

@ampolloreno

@hawkinsp Point taken and... Thanks for the help! I just switched over to V100s and voila!

@liyc-ai

liyc-ai commented Jul 12, 2023

I got the same error; it may be due to a mismatch between the installed CUDA version and the installed JAX wheel. I use Ubuntu 20.04 with the CUDA version below:
[screenshot of the CUDA version]
At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the reported error, so I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!
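The `CUDA Version` shown in the `nvidia-smi` banner is the newest CUDA version the installed driver supports, not the toolkit that happens to be installed. A heuristic sketch (hypothetical helpers, using the banner from a later report in this thread as sample input) that reads it and checks that a wheel's CUDA major version is not newer than what the driver supports; it does not capture finer minimum-driver requirements:

```python
import re
from typing import Optional, Tuple


def driver_cuda_version(nvidia_smi_output: str) -> Optional[Tuple[int, int]]:
    """Read 'CUDA Version: X.Y' from the nvidia-smi banner (driver capability)."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def wheel_not_newer_than_driver(wheel_cuda_major: int,
                                driver_cuda: Tuple[int, int]) -> bool:
    """Heuristic: a cudaNN wheel needs a driver supporting at least CUDA NN.x."""
    return wheel_cuda_major <= driver_cuda[0]


banner = "| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |"
driver = driver_cuda_version(banner)
for extra, major in (("cuda11_pip", 11), ("cuda12_pip", 12)):
    print(extra, wheel_not_newer_than_driver(major, driver))
```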

@tuzhucheng

Thanks @hosseybposh: for a simple use case I was able to use JAX 0.4.13 and CUDA 11.8 with cuDNN 8.6. I needed to add /usr/lib/x86_64-linux-gnu to LD_LIBRARY_PATH (I installed libcudnn8 with apt-get).
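The dynamic loader reads `LD_LIBRARY_PATH` when a process starts, so changing `os.environ` inside an already running Python process does not change which libraries that process loads. A sketch (hypothetical helpers) that builds an environment for a fresh child process with an extra library directory moved to the front, so it is searched first:

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional


def env_with_library_dir(directory: str,
                         base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of `base` (default: os.environ) with `directory` first in LD_LIBRARY_PATH."""
    env = dict(os.environ if base is None else base)
    parts = [p for p in env.get("LD_LIBRARY_PATH", "").split(":") if p]
    if directory in parts:
        parts.remove(directory)  # avoid duplicates; re-added at the front below
    parts.insert(0, directory)  # prepend so this directory is searched first
    env["LD_LIBRARY_PATH"] = ":".join(parts)
    return env


def run_with_library_dir(script: str, directory: str) -> int:
    """Launch a fresh Python process so the loader sees the new LD_LIBRARY_PATH."""
    completed = subprocess.run([sys.executable, script], env=env_with_library_dir(directory))
    return completed.returncode
```

For example, `run_with_library_dir("train.py", "/usr/lib/x86_64-linux-gnu")`, where `train.py` is a placeholder for your own script.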

@TInaWangxue

I also ran into this error.
The output:
2023-07-31 01:53:45.016563: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.

XlaRuntimeError Traceback (most recent call last)
Cell In[4], line 29
26 model = trainer.make_model(nmask)
28 lr_fn, opt = trainer.make_optimizer(steps_per_epoch=len(train_dl))
---> 29 state = trainer.create_train_state(jax.random.PRNGKey(0), model, opt)
30 state = checkpoints.restore_checkpoint(ckpt.parent, state)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/random.py:137, in PRNGKey(seed)
134 if np.ndim(seed):
135 raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
136 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 137 key = prng.seed_with_impl(impl, seed)
138 return _return_prng_keys(True, key)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:320, in seed_with_impl(impl, seed)
319 def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
--> 320 return random_seed(seed, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:734, in random_seed(seeds, impl)
732 else:
733 seeds_arr = jnp.asarray(seeds)
--> 734 return random_seed_p.bind(seeds_arr, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:790, in EvalTrace.process_primitive(self, primitive, tracers, params)
789 def process_primitive(self, primitive, tracers, params):
--> 790 return primitive.impl(*tracers, **params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:746, in random_seed_impl(seeds, impl)
744 @random_seed_p.def_impl
745 def random_seed_impl(seeds, *, impl):
--> 746 base_arr = random_seed_impl_base(seeds, impl=impl)
747 return PRNGKeyArrayImpl(impl, base_arr)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:751, in random_seed_impl_base(seeds, impl)
749 def random_seed_impl_base(seeds, *, impl):
750 seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 751 return seed(seeds)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:980, in threefry_seed(seed)
968 def threefry_seed(seed: typing.Array) -> typing.Array:
969 """Create a single raw threefry PRNG key from an integer seed.
970
971 Args:
(...)
978 first padding out with zeros).
979 """
--> 980 return _threefry_seed(seed)

[... skipping hidden 12 frame]

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/dispatch.py:463, in backend_compile(backend, module, options, host_callbacks)
458 return backend.compile(built_c, compile_options=options,
459 host_callbacks=host_callbacks)
460 # Some backends don't have host_callbacks option yet
461 # TODO(sharadmv): remove this fallback when all backends allow compile
462 # to take in host_callbacks
--> 463 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.11.4, Ubuntu 22.04

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:18:00.0 Off | N/A |
| 30% 33C P8 22W / 350W | 8688MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... On | 00000000:3B:00.0 Off | N/A |
| 30% 31C P8 14W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce ... On | 00000000:86:00.0 Off | N/A |
| 30% 34C P8 24W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce ... On | 00000000:AF:00.0 Off | N/A |
| 30% 30C P8 8W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 2861 C+G ...ome-remote-desktop-daemon 249MiB |
| 0 N/A N/A 3213874 C ...ransformer_117/bin/python 8430MiB |
| 1 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------+
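The first log line in this report states the rule XLA applies: the loaded cuDNN must have the same major version as the one jaxlib was compiled with, and an equal or higher minor version. Here 8.5 was loaded against an 8.6 build, which fails that rule. A sketch of the check with hypothetical helper names, including a parser for that log message:

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int]

_MISMATCH = re.compile(
    r"Loaded runtime CuDNN library: (\d+)\.(\d+)\.\d+ "
    r"but source was compiled with: (\d+)\.(\d+)\.\d+"
)


def cudnn_runtime_compatible(runtime: Version, compiled: Version) -> bool:
    """Rule quoted in the log: same major version, runtime minor >= compiled minor."""
    return runtime[0] == compiled[0] and runtime[1] >= compiled[1]


def parse_mismatch(log_line: str) -> Optional[Tuple[Version, Version]]:
    """Return (runtime, compiled) as (major, minor) pairs from the XLA error line."""
    match = _MISMATCH.search(log_line)
    if match is None:
        return None
    g = [int(x) for x in match.groups()]
    return (g[0], g[1]), (g[2], g[3])
```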

@hawkinsp
Collaborator

@TInaWangxue's problem was resolved in #16901.

@cloudinging

Hi, I have a similar issue. Please help!

The output:
2023-09-05 14:32:56.559501: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-05 14:32:56.559528: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6081413120 bytes free, 25438126080 bytes total.
Traceback (most recent call last):
File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 453, in <module>
app.run(main)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 428, in main
predict_structure(
File "/home/wangyun/pre/alphafold-multimer-main/run_alphafold.py", line 214, in predict_structure
prediction_result = model_runner.predict(processed_feature_dict,
File "/home/wangyun/pre/alphafold-multimer-main/alphafold/model/model.py", line 167, in predict
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
return seed(seeds)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
return _threefry_seed(seed)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 208, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 2633, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/pjit.py", line 1085, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/wangyun/miniconda3/envs/mutimer3/lib/python3.9/site-packages/jax/_src/dispatch.py", line 462, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?
jax 0.4.9, jaxlib 0.4.9+cuda11.cudnn86

My conda virtual environment:
python3.9.0
cudatoolkit 11.8.0 h4ba93d1_12 conda-forge
cudnn 8.6.0.163 hed8a83a_0 cudistas

But my OS environment:
NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2
ii cudnn-local-repo-ubuntu1804-8.9.3.28 1.0-1 amd64 cudnn-local repository configuration files

I have tried everything I can, but...

@April-ppigg

What is your OS? Can you confirm that you run the scripts sequentially, so that nothing else is using the GPU in parallel?

Hi, I encountered the same problem. When I run a single task on an A100, it runs normally, but when I submit two tasks at the same time, the error above is reported. Is the cause that the A100 is running two tasks at once? Can they conflict?

@nouiz
Collaborator

nouiz commented Oct 16, 2023


I assume 2 tasks means 2 processes. If not, let us know.
By default, JAX will reserve 75% of the GPU memory for the process:
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

So the second process will usually run out of GPU memory. That page explains how to change the 75% preallocation. If you lower it to 45% and the first process still has enough memory, it will probably work. Otherwise, try a few other values.
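A minimal sketch of that approach, assuming two processes share one GPU. `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE` are the environment variables described on the GPU memory allocation page linked above; the helpers `per_process_mem_fraction` and `configure_jax_gpu_memory` and the `headroom` parameter are hypothetical, for illustration only. The variables need to be set before JAX initializes its GPU backend, so the safest place is before `import jax`.

```python
import os


def per_process_mem_fraction(n_processes: int, headroom: float = 0.1) -> str:
    """Hypothetical helper: split GPU memory evenly between n_processes.

    `headroom` is a fraction of total memory left unreserved (e.g. for other
    processes or libraries). Returned as a string because environment
    variables are strings.
    """
    if n_processes < 1:
        raise ValueError("n_processes must be >= 1")
    if not 0.0 <= headroom < 1.0:
        raise ValueError("headroom must be in [0, 1)")
    return f"{(1.0 - headroom) / n_processes:.2f}"


def configure_jax_gpu_memory(n_processes: int, preallocate: bool = True) -> None:
    """Hypothetical helper: set JAX's GPU memory variables for this process.

    Call it before `import jax`; JAX reads these when its GPU backend starts.
    """
    if preallocate:
        # Preallocate only this fraction of GPU memory instead of the default 75%.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = per_process_mem_fraction(n_processes)
    else:
        # Disable preallocation; JAX then allocates GPU memory on demand.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


if __name__ == "__main__":
    configure_jax_gpu_memory(n_processes=2)
    print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])  # 0.45
    # import jax  # import JAX only after the variables are set
```

Running each process with the same setting (for example from the job script) is the usual way to use this; the values can equally be exported in the shell before launching Python.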

@cloudinging


Update: it works now. You can look at this link: https://blog.csdn.net/2201_75882736/article/details/132812927

@William-HYWu

Hi, I also have the same issue. Could anyone please help me?
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
train(parser.parse_args())
File "/bd_byt4090i1/users/state_space_model/DNN/S5/s5/train.py", line 41, in train
key = random.PRNGKey(args.jax_seed)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/random.py", line 160, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
return random_seed(seed, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 690, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
return seed(seeds)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 936, in threefry_seed
return _threefry_seed(seed)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
always_lower=False, lowering_platform=None).compile()
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/bd_byt4090i1/users/state_space_model/DNN/S5/run_train.py", line 101, in <module>
train(parser.parse_args())
File "/bd_byt4090i1/users/state_space_model/DNN/S5/s5/train.py", line 41, in train
key = random.PRNGKey(args.jax_seed)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/random.py", line 160, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 406, in seed_with_impl
return random_seed(seed, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 690, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 702, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 707, in random_seed_impl_base
return seed(seeds)
File "/bd_byt4090i1/users/state_space_model/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/prng.py", line 936, in threefry_seed
return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

My jax version:
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86
flax==0.7.4
chex==0.1.8

My gpu information:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02 Driver Version: 470.223.02 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:83:00.0 Off | N/A |
| 21% 35C P0 40W / 215W | 0MiB / 7982MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+

It should be an RTX 2070. Thanks a lot for the help!

@nouiz
Collaborator

nouiz commented Apr 1, 2024

You are using an old JAX version (0.4.13) and an old driver that only supports up to CUDA 11.4, which is older than CUDA 11.8.
Can you update your NVIDIA driver to one that supports at least CUDA 11.8? That is the minimum version JAX currently supports. JAX is dropping CUDA 11 in upcoming releases, so updating to CUDA 12 would be even better.
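One way to check this is to read the `CUDA Version` field in the `nvidia-smi` header, which reports the highest CUDA version the installed driver supports, and compare it with the 11.8 minimum mentioned above. This is a minimal sketch: the function names are hypothetical, and the regular expression assumes the header format shown in this thread.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Minimum CUDA version cited in the reply above (at the time of writing).
MIN_CUDA: Tuple[int, int] = (11, 8)

# nvidia-smi prints e.g. "Driver Version: 470.223.02   CUDA Version: 11.4" in its header.
_CUDA_RE = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def driver_cuda_version(smi_output: str) -> Optional[Tuple[int, int]]:
    """Return (major, minor) from the 'CUDA Version' field, or None if absent."""
    match = _CUDA_RE.search(smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def driver_supports(smi_output: str, minimum: Tuple[int, int] = MIN_CUDA) -> bool:
    """True if the driver's maximum supported CUDA version is at least `minimum`."""
    version = driver_cuda_version(smi_output)
    # Tuples compare element-wise: (11, 4) < (11, 8) < (12, 2).
    return version is not None and version >= minimum


def read_nvidia_smi() -> str:
    """Run nvidia-smi if it is on PATH; return '' otherwise."""
    if shutil.which("nvidia-smi") is None:
        return ""
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return result.stdout


if __name__ == "__main__":
    # Header line from the report above (RTX 2070, driver 470).
    header = "| NVIDIA-SMI 470.223.02 Driver Version: 470.223.02 CUDA Version: 11.4 |"
    print(driver_cuda_version(header), driver_supports(header))  # (11, 4) False
    live = read_nvidia_smi()
    if live:
        print("this machine:", driver_cuda_version(live), driver_supports(live))
```

Note that the `nvcc --version` output reports the installed CUDA toolkit, not what the driver supports, so the two can differ.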

@William-HYWu


Thanks a lot. I tried that, and it worked!

@crshin

crshin commented May 13, 2024

Hi, guys.
I came here on someone's recommendation.
I'm facing a similar issue: "Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed."
I tried almost everything suggested on these pages to resolve the GPU memory problem:
YoshitakaMo/localcolabfold#210
YoshitakaMo/localcolabfold#224
YoshitakaMo/localcolabfold#228

My current jax, cudnn, and nvidia-smi versions are as follows.
(Linux Ubuntu 22.04.2 LTS and RTX 4090)

$nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0

$python3.10 -m pip list | grep jax
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23+cuda12.cudnn89

$python3.10 -m pip list | grep cudnn
jaxlib                   0.4.23+cuda12.cudnn89
nvidia-cudnn-cu12        9.1.0.70

$nvidia-smi
Fri May 10 07:20:32 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+

Here's the problem I'm encountering:

$colabfold_batch --templates --amber test_A.fasta ./
2024-05-13 06:17:37,861 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)
2024-05-13 06:17:38,189 Running on GPU
2024-05-13 06:17:38,738 Found 9 citations for tools or databases
2024-05-13 06:17:38,738 Query 1/1: pdb_A (length 108)
2024-05-13 06:17:41,729 Sequence 0 found templates: ['1m4u_L', '2r52_B', '6oml_Y', '5vt2_B', '2r53_A', '1lxi_A', '4n1d_A', '7zjf_B', '7zjf_A', '6z3g_A', '3qb4_C', '3qb4_A', '6z3j_A', '2h64_A', '4uhy_A', '1reu_A', '4ui0_A', '2h62_B', '4mid_A', '3bk3_B']
2024-05-13 06:17:42,533 Setting max_seq=512, max_extra_seq=5120
2024-05-13 06:17:42,674 Could not predict pdb_A. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-13 06:17:42,674 Done

I can't find "the errors above."

Can someone offer any guesses or solutions?

@hawkinsp
Collaborator

JAX releases don't support cuDNN 9 yet. Downgrade to cuDNN 8.9 (or build jaxlib from source against cuDNN 9, which also works).
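As a rough way to spot this kind of mismatch, the sketch below compares the cuDNN version encoded in the jaxlib wheel's local version tag (for example `0.4.23+cuda12.cudnn89` in the report above, built for cuDNN 8.9) with the installed `nvidia-cudnn-cu12` pip package (9.1.0.70 in that report). It is a minimal sketch under assumptions: the `cudnnXY` tag format is inferred from the versions posted in this thread and is not present on every jaxlib wheel, only the major version is compared, and the helper names are hypothetical. It uses only the standard-library `importlib.metadata`.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Matches the "cudnn89"-style suffix seen in versions like "0.4.23+cuda12.cudnn89".
# Assumes a one-digit major version followed by the minor version.
_TAG_RE = re.compile(r"cudnn(\d)(\d+)")


def jaxlib_cudnn_tag(jaxlib_version: str) -> Optional[Tuple[int, int]]:
    """'0.4.23+cuda12.cudnn89' -> (8, 9); None if there is no cudnn tag."""
    match = _TAG_RE.search(jaxlib_version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def major_version(version: str) -> int:
    """'9.1.0.70' -> 9"""
    return int(version.split(".")[0])


def installed_version(dist_name: str) -> Optional[str]:
    """Version string of an installed pip distribution, or None if not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def cudnn_mismatch(jaxlib_version: str, cudnn_version: str) -> bool:
    """True if the cuDNN major version jaxlib was built for differs from the installed one."""
    tag = jaxlib_cudnn_tag(jaxlib_version)
    if tag is None:
        return False  # nothing to compare against
    return tag[0] != major_version(cudnn_version)


if __name__ == "__main__":
    # Values from the report above: jaxlib built for cuDNN 8.9, cuDNN 9.1 installed.
    print(cudnn_mismatch("0.4.23+cuda12.cudnn89", "9.1.0.70"))  # True
    jaxlib_v = installed_version("jaxlib")
    cudnn_v = installed_version("nvidia-cudnn-cu12")
    if jaxlib_v and cudnn_v:
        print(jaxlib_v, cudnn_v, "mismatch" if cudnn_mismatch(jaxlib_v, cudnn_v) else "ok")
```

A match here does not rule out other causes (for example a system cuDNN picked up first via `LD_LIBRARY_PATH`); it only flags the specific case discussed above.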

@AWangji

AWangji commented May 17, 2024

I got the same error; it may be due to a mismatch between your CUDA version and the installed JAX. I use Ubuntu 20.04 with the CUDA version shown below: [screenshot] At first, I installed the newest JAX with

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then I got the error as reported. So I switched to

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Everything works well!

But jaxlib has no cuda11 build:
[screenshot]

@Lvchangze


Great!

@agnikumar

I'm getting
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed.

Does anyone have ideas on how to resolve this? JAX version is 0.4.26, CUDA version is 12.4, and cuDNN version is 8.9.7.29, which should be compatible.
[screenshot]

@TTy32

TTy32 commented Dec 6, 2024


I had the same issue. I solved it by building a clean environment (Python 3.10, CUDA 12.1, PyTorch 2.5.1+cu121).
Make sure to set LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH

Test:

python -c "import torch; print(torch.__version__, torch.version.cuda)"
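One catch when applying the `LD_LIBRARY_PATH` tip from Python rather than the shell: the Linux dynamic loader reads `LD_LIBRARY_PATH` when a process starts, so changing `os.environ` inside an already-running interpreter generally does not affect libraries that same process loads later. The sketch below, assuming the `/usr/local/cuda-12.1/lib64` path from the comment above, builds an environment with that directory prepended and runs the check in a fresh interpreter. The helper names are hypothetical.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional

# Directory taken from the comment above; adjust to your CUDA installation.
CUDA_LIB_DIR = "/usr/local/cuda-12.1/lib64"


def prepend_path(new_dir: str, existing: Optional[str]) -> str:
    """Return new_dir placed first in a ':'-separated search path.

    Unlike the shell idiom NEW:$LD_LIBRARY_PATH, this adds no trailing ':'
    when the variable is unset (an empty entry means "current directory").
    """
    if not existing:
        return new_dir
    # Drop any existing copy so the directory appears exactly once, at the front.
    parts = [p for p in existing.split(":") if p != new_dir]
    return ":".join([new_dir] + parts)


def env_with_cuda_libs(lib_dir: str = CUDA_LIB_DIR,
                       base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of `base` (default: os.environ) with lib_dir prepended to LD_LIBRARY_PATH."""
    env = dict(os.environ if base is None else base)
    env["LD_LIBRARY_PATH"] = prepend_path(lib_dir, env.get("LD_LIBRARY_PATH"))
    return env


def run_check(code: str, lib_dir: str = CUDA_LIB_DIR) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter, whose loader sees the new LD_LIBRARY_PATH."""
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env_with_cuda_libs(lib_dir),
        capture_output=True,
        text=True,
        check=False,
    )


if __name__ == "__main__":
    # The same test as the one-liner above, run with the CUDA 12.1 libraries first.
    result = run_check("import torch; print(torch.__version__, torch.version.cuda)")
    print(result.stdout or result.stderr)
```

The same `run_check` call works for a JAX check by passing different code, e.g. a script that imports `jax` and prints `jax.devices()`.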

@MikaBell

MikaBell commented Jan 2, 2025

I'm having the same issue; can someone please help?
This is the error I get:

y = jnp.dot(x, x)
E0102 13:38:39.485608 240885 pjrt_stream_executor_client.cc:3085] Execution of replica 0 failed: INTERNAL: the library was not initialized
Traceback (most recent call last):
File "", line 1, in
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: the library was not initialized


For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0

nvidia-smi
Thu Jan 2 13:39:26 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02 Driver Version: 560.94 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4080 On | 00000000:01:00.0 On | N/A |
| 0% 29C P8 13W / 340W | 3247MiB / 16376MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 23 G /Xwayland N/A |
+-----------------------------------------------------------------------------------------+

pip show jax jaxlib
Name: jax
Version: 0.4.38
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/mika1309/.pyenv/versions/3.12.0/lib/python3.12/site-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by:

Name: jaxlib
Version: 0.4.38
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/mika1309/.pyenv/versions/3.12.0/lib/python3.12/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: jax

python --version
Python 3.12.0
