Cryptic XLA error when using JAX 0.4.26 #21396

Closed
itk22 opened this issue May 23, 2024 · 12 comments
Labels
bug Something isn't working NVIDIA GPU Issues specific to NVIDIA GPUs

Comments

@itk22

itk22 commented May 23, 2024

Description

Dear JAX Team,

I am migrating my code from jax 0.4.25 to jax 0.4.26 so that I can use it with the latest NGC container, and I am running into an XLA error that I cannot quite understand. Unfortunately, I have not been able to reproduce the error with a small piece of code. However, I have observed that the error occurs only when the batch size is greater than one. Here is the error message:

2024-05-23 15:20:04.554952: F external/xla/xla/shape_tree.cc:54] Check failed: result->children_start_id >= 0 (0 vs. -1)

Could you please help me diagnose and resolve this issue? Any guidance would be greatly appreciated.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='IgorPC', release='5.15.146.1-microsoft-standard-WSL2', version='#1 SMP Thu Jan 11 04:09:03 UTC 2024', machine='x86_64')
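The report above follows the format printed by `jax.print_environment_info()`. A minimal sketch for collecting the same details when filing a report, plus the headline facts as a dict for logging (the `info` dict is an illustrative addition, not something asked for in this thread):

```python
from importlib.metadata import version

import jax

# Prints jax/jaxlib/numpy/python versions, the visible devices, the process
# count and the platform (the same fields as the report above). On machines
# with NVIDIA GPUs it may also append nvidia-smi output.
jax.print_environment_info()

# The headline facts as a dict, e.g. for attaching to experiment logs.
info = {
    "jax": jax.__version__,
    "jaxlib": version("jaxlib"),  # installed wheel version, read from package metadata
    "backend": jax.default_backend(),  # e.g. "gpu" or "cpu"
    "devices": [str(d) for d in jax.devices()],
}
print(info)
```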
@itk22 itk22 added the bug Something isn't working label May 23, 2024
@hawkinsp
Collaborator

Can you please try with (a) a fresh virtualenv and (b) using jax[cuda12] v0.4.28, which is the current version?
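For readers who want to follow this suggestion from Python rather than the shell, here is a minimal sketch using only the standard library. The directory name `jax-fresh-env` and the helper names are hypothetical. Note that `clear=True` deletes the directory's existing contents, and the install step needs network access.

```python
import subprocess
import sys
import venv
from pathlib import Path


def make_fresh_env(env_dir: str) -> Path:
    """Create a new, empty virtualenv (with pip) and return its interpreter path."""
    # clear=True wipes env_dir first, so no stale packages survive.
    venv.EnvBuilder(with_pip=True, clear=True).create(env_dir)
    if sys.platform == "win32":
        return Path(env_dir) / "Scripts" / "python.exe"
    return Path(env_dir) / "bin" / "python"


def install_cmd(python: Path, spec: str = "jax[cuda12]==0.4.28") -> list[str]:
    """Build the pip command that installs `spec` into the env owning `python`."""
    return [str(python), "-m", "pip", "install", "--upgrade", spec]


def setup(env_dir: str = "jax-fresh-env") -> Path:
    """Create the env, install JAX into it, and return the env's interpreter."""
    python = make_fresh_env(env_dir)
    subprocess.run(install_cmd(python), check=True)
    return python
```

Calling `setup()` and then running the failing script with the returned interpreter isolates the test from whatever else is installed in the original environment.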

@itk22
Author

itk22 commented May 24, 2024

Hi @hawkinsp,

Thank you for the quick response. I followed your instructions but encountered the same error. To provide more context, the error occurs during batched operations with a model written in Equinox. Unfortunately, I have not been able to recreate it with simpler examples, leading me to believe the issue is not on Equinox's side. @patrick-kidger, have you encountered similar issues by any chance?

@patrick-kidger
Collaborator

I've not seen this one before I'm afraid!

@hawkinsp
Collaborator

hawkinsp commented May 24, 2024

Can you share an HLO dump? That might be enough for us to reproduce. Run with XLA_FLAGS=--xla_dump_to=/somewhere, zip up /somewhere and attach it to this bug.
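A minimal sketch of producing such a dump from inside a Python script. It assumes XLA parses `XLA_FLAGS` once, when the backend is first initialised, so the variable is set before `jax` is imported. The function `step` is a hypothetical stand-in for the real failing computation, and a temporary directory plays the role of `/somewhere`.

```python
import os
import shutil
import tempfile

# Placeholder for "/somewhere": a fresh temporary directory.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
# Append to any existing XLA_FLAGS rather than overwriting them.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + f" --xla_dump_to={dump_dir}"
).strip()

import jax  # noqa: E402  (must come after XLA_FLAGS is set)
import jax.numpy as jnp  # noqa: E402


@jax.jit
def step(x):
    # Hypothetical stand-in for the computation that triggers the bug.
    return jnp.tanh(x @ x.T).sum()


step(jnp.ones((2, 4)))  # compiling writes HLO files into dump_dir

# Zip the dump directory so it can be attached to the issue.
archive = shutil.make_archive(
    os.path.join(tempfile.gettempdir(), "xla_dump"), "zip", dump_dir
)
print(f"{len(os.listdir(dump_dir))} files dumped; archive at {archive}")
```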

@itk22
Author

itk22 commented May 24, 2024

@hawkinsp, here is the HLO dump for the run with 0.4.28

dump.zip

@hawkinsp
Collaborator

Hmm. I can't reproduce from the HLO dump. I think we'll need a Python-level reproduction.

@hawkinsp hawkinsp added needs info More information is required to diagnose & prioritize the issue. NVIDIA GPU Issues specific to NVIDIA GPUs labels May 24, 2024
@hawkinsp
Collaborator

Actually, never mind, I can reproduce from the HLO.

@hawkinsp hawkinsp removed the needs info More information is required to diagnose & prioritize the issue. label May 24, 2024
@itk22
Author

itk22 commented May 29, 2024

Hi @hawkinsp,

I am checking in to see if there are any updates on this issue. Any information would be helpful. Thanks!

@hawkinsp
Collaborator

hawkinsp commented May 29, 2024

I filed an internal bug for our XLA compiler folks (b/342589917), and I'm waiting for one of them to take a look.

I have no additional information other than: yes, I can reproduce the problem from the HLO dump.

@hawkinsp
Collaborator

hawkinsp commented Jun 4, 2024

A fix for this was merged (openxla/xla@e7bd8ad). The fix should be in today's jaxlib nightly. Please try it out and let me know if the problem is fixed.
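Before re-running, it helps to confirm which wheels are actually installed. Here is a small sketch that reads package metadata. It assumes that nightly builds carry a PEP 440 `.dev` segment (for example something like `0.4.30.dev20240604`); check the strings on your own install. The CUDA plugin package names listed are the ones used by recent CUDA 12 installs and may not all be present.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed(pkg: str) -> Optional[str]:
    """Return the installed version of `pkg`, or None if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def is_dev_build(ver: Optional[str]) -> bool:
    """Heuristic: nightly wheels usually carry a PEP 440 ".devN" segment."""
    return ver is not None and ".dev" in ver


for pkg in ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"):
    ver = installed(pkg)
    print(f"{pkg:20s} {ver or 'not installed':30s} nightly={is_dev_build(ver)}")
```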

@hawkinsp hawkinsp closed this as completed Jun 4, 2024
@itk22
Author

itk22 commented Jun 9, 2024

Hi @hawkinsp,
Apologies for the late response. I tried following the installation instructions for the JAX nightly, but I ended up with the following error:

[INFO 06-10 00:27:06] metatopia: Running version 0.0.1 on the GPU.
E0610 00:27:06.855497   17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.855663   17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
E0610 00:27:06.855951   17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.856039   17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/igork/projects/metatopia/experiments/topology_optimisation/run.py", line 6, in <module>
    import metatopia as mtp
  File "/home/igork/projects/metatopia/metatopia/__init__.py", line 21, in <module>
    from metatopia import (task_generation, filters, solver, models, utils,
  File "/home/igork/projects/metatopia/metatopia/task_generation/__init__.py", line 1, in <module>
    from .problems import (mbb_beam, generate_random_2D_problem,
  File "/home/igork/projects/metatopia/metatopia/task_generation/problems.py", line 175, in <module>
    fixed_key: jr.PRNGKey = jr.PRNGKey(0)):
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 233, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 195, in _key
    return prng.random_seed(seed, impl=impl)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/prng.py", line 532, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3153, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3078, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 559, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 909, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

@JohannesAck

JohannesAck commented Jun 11, 2024

For what it's worth, I encountered the same error with jaxlib 0.4.28 when using jax.experimental.io_callback, but I haven't yet managed to reproduce it with a minimal Python example.

I also tried to install the nightly, and got the exact same error as @itk22 even when only trying to run a simple command:
Doing a clean install fixed the issues with running the nightly.
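The thread does not say what the clean install changed. One plausible cause of the DNN initialization failure is a mix of old and new jax/CUDA wheels visible in the same environment, though this is an assumption and is not confirmed here. A sketch for spotting that by listing every jax- and nvidia-prefixed distribution and flagging names that appear with more than one version:

```python
from collections import defaultdict
from importlib.metadata import distributions


def jax_related_packages(prefixes=("jax", "nvidia-")) -> dict[str, set[str]]:
    """Map each matching distribution name to the set of versions found on sys.path."""
    found = defaultdict(set)
    for dist in distributions():
        # Normalise names so "jax_cuda12_plugin" and "jax-cuda12-plugin" match.
        name = (dist.metadata["Name"] or "").lower().replace("_", "-")
        if name.startswith(prefixes):
            found[name].add(dist.version)
    return dict(found)


for name, versions in sorted(jax_related_packages().items()):
    flag = "  <-- multiple versions visible" if len(versions) > 1 else ""
    print(f"{name}: {', '.join(sorted(versions))}{flag}")
```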

The nightly then produced a much clearer traceback, which allowed me to fix the issue: I had specified the wrong shape for the return value of jax.experimental.io_callback.
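For reference, here is a minimal sketch of an `io_callback` whose declared output matches what the host function returns. The function names are illustrative. Per the comment above, a mismatch between the declared `result_shape_dtypes` and the array the callback actually returns was the root cause in that case. For example, declaring shape `(3,)` while returning `(3, 1)` would be such a mismatch.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import io_callback


def host_double(x):
    # Runs on the host. Whatever it returns must match the declared
    # result_shape_dtypes exactly (same shape and dtype).
    return (np.asarray(x) * 2).astype(np.float32)


@jax.jit
def f(x):
    # Declare the callback's output: same shape as x, float32.
    out_spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    return io_callback(host_double, out_spec, x)


y = f(jnp.arange(3, dtype=jnp.float32))
print(y)
```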

4 participants