worse precision when using vmap #15362
Comments
Thanks for the report. This appears to be backend-dependent behavior: I cannot reproduce the issue on CPU or on a T4 GPU. Perhaps it's somehow specific to the A300?
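One way to check whether a discrepancy like this is backend-dependent is to run the same jitted computation on the default device and on the CPU, then compare both against a float64 reference. The sketch below assumes a single float32 matrix product as a stand-in for the network; the function name `batched_matmul` and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def batched_matmul(x, w):
    # Hypothetical stand-in for the network: one float32 matrix product.
    return jnp.dot(x, w)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (100, 64), dtype=jnp.float32)
w = jax.random.normal(key_w, (64, 64), dtype=jnp.float32)

# Runs on whichever backend JAX selected by default (a GPU, if one is visible).
default_out = np.asarray(batched_matmul(x, w))

# Commit copies of the inputs to the CPU so the same computation runs there.
cpu = jax.devices("cpu")[0]
cpu_out = np.asarray(batched_matmul(jax.device_put(x, cpu), jax.device_put(w, cpu)))

# float64 NumPy reference used to measure each backend's error.
ref = np.asarray(x, dtype=np.float64) @ np.asarray(w, dtype=np.float64)

print("default backend:", jax.default_backend())
print("max |default - ref|:", np.abs(default_out - ref).max())
print("max |cpu - ref|:    ", np.abs(cpu_out - ref).max())
```

If only the default-device line shows an error well above float32 rounding level, that would suggest the accelerator's matmul path rather than the model code.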
To fix this, try explicitly setting higher precision, i.e., Note that you may have to upgrade to a very recent release of JAX to get this working, since this was fixed only recently: #14022
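As a hedged sketch (not necessarily the exact setting meant in the comment above), JAX documents several ways to request full float32 precision for matrix multiplications, which on recent NVIDIA GPUs may otherwise run at reduced TF32 precision (an assumption about the cause here, not something confirmed in this thread): a per-call `precision` argument, the `jax.default_matmul_precision` context manager, and the global `jax_default_matmul_precision` config option. The `layer` function and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
w = jnp.asarray(rng.normal(size=(32, 32)), dtype=jnp.float32)
xs = jnp.asarray(rng.normal(size=(100, 32)), dtype=jnp.float32)


def layer(w, x):
    # Hypothetical single-example layer: matrix-vector product followed by ReLU.
    return jax.nn.relu(jnp.dot(w, x))


# Option 1: request full precision on an individual call.
y1 = jnp.dot(xs, w.T, precision=jax.lax.Precision.HIGHEST)

# Option 2: raise the default precision for every matmul traced in this block,
# including the batched matmul that vmap produces from the matrix-vector product.
with jax.default_matmul_precision("float32"):
    y2 = jax.jit(jax.vmap(layer, in_axes=(None, 0)))(w, xs)

# Option 3: set it globally, typically once at program start.
# jax.config.update("jax_default_matmul_precision", "float32")
```

Option 1 is the most targeted, while options 2 and 3 also cover matmuls you don't write yourself, such as those created when `vmap` batches a matrix-vector product.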
@shoyer I can confirm that setting
Description
I was following the network training tutorial and noticed a case where using `vmap` on the `predict` function for a 4-layer network with `input_dim=output_dim=1` results in significantly worse numerical accuracy.

Is this behavior expected, or is it a bug? I saw the related question #8712, but the answer there indicates that both approaches are comparable, while in this case using `vmap` is clearly worse in terms of precision.

With float64 enabled, both outputs are the same, even when setting `x=jnp.linspace(0, 1e-30, 100)[:, None]`.
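A minimal sketch of the comparison described above, assuming a tutorial-style MLP with ReLU hidden layers and `input_dim=output_dim=1`. The hidden width, initialization scale, helper names (`init_params`, `predict_batched`, `predict_ref`), and the use of a hand-batched function as the non-`vmap` variant are all hypothetical; a float64 NumPy evaluation stands in for the float64 run.

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_params(sizes, key, scale=0.1):
    # Hypothetical init: one (weights, biases) pair per layer, weights shaped (n_out, n_in).
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, w_key, b_key = jax.random.split(key, 3)
        params.append((scale * jax.random.normal(w_key, (n_out, n_in)),
                       scale * jax.random.normal(b_key, (n_out,))))
    return params


def predict(params, x):
    # Single example: x has shape (input_dim,), as in the tutorial-style predict.
    activations = x
    for w, b in params[:-1]:
        activations = jnp.maximum(0, jnp.dot(w, activations) + b)
    final_w, final_b = params[-1]
    return jnp.dot(final_w, activations) + final_b


def predict_batched(params, xs):
    # Hand-batched variant: xs has shape (batch, input_dim).
    activations = xs
    for w, b in params[:-1]:
        activations = jnp.maximum(0, activations @ w.T + b)
    final_w, final_b = params[-1]
    return activations @ final_w.T + final_b


def predict_ref(params, xs):
    # Same network evaluated in float64 with NumPy, used as the reference.
    activations = np.asarray(xs, dtype=np.float64)
    for w, b in params[:-1]:
        w, b = np.asarray(w, np.float64), np.asarray(b, np.float64)
        activations = np.maximum(0, activations @ w.T + b)
    final_w, final_b = (np.asarray(p, np.float64) for p in params[-1])
    return activations @ final_w.T + final_b


params = init_params([1, 512, 512, 512, 1], jax.random.PRNGKey(0))  # 4 layers
xs = jnp.linspace(0, 1, 100)[:, None]  # shape (100, 1)

out_vmap = np.asarray(jax.vmap(predict, in_axes=(None, 0))(params, xs))
out_batched = np.asarray(predict_batched(params, xs))
ref = predict_ref(params, xs)

print("max |vmap - ref|:   ", np.abs(out_vmap - ref).max())
print("max |batched - ref|:", np.abs(out_batched - ref).max())
```

On a CPU, both errors would be expected to sit near float32 rounding level; the behavior reported here is that on the GPU the `vmap` error is much larger.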
What jax/jaxlib version are you using?
jax v0.4.8, jaxlib v0.4.7
Which accelerator(s) are you using?
GPU
Additional system info
Ubuntu 22.04
NVIDIA GPU info
```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A300...  On   | 00000000:01:00.0  On |                  Off |
| N/A   52C    P5    19W / 119W |   9739MiB / 12288MiB |     39%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2219      G   /usr/lib/xorg/Xorg                105MiB |
|    0   N/A  N/A    132485      C   ...tions/venv3.10/bin/python     9630MiB |
+-----------------------------------------------------------------------------+
```