
worse precision when using vmap #15362

Closed
Linusnie opened this issue Apr 2, 2023 · 3 comments

Comments


Linusnie commented Apr 2, 2023

Description

I was following the neural network training tutorial and noticed a case where using vmap on the predict function, for a fully-connected network with layer sizes [1, 512, 512, 1] (input_dim = output_dim = 1), gives significantly worse numerical accuracy:

import jax.numpy as jnp
import matplotlib.pyplot as plt  # needed for the comparison plot below
from jax import vmap
from jax import random

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [1, 512, 512, 1]
# training hyperparameters from the tutorial (unused in this reproduction)
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

# evaluate the same network with vmap and with a plain Python loop over examples
x = jnp.linspace(0, 1e-2, 100)[:, None]
y_vmap = vmap(predict, in_axes=(None, 0))(params, x)
y = jnp.array([predict(params, xx) for xx in x])

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(x, y, '-o', label='no vmap')
ax.plot(x, y_vmap, '-o', label='vmap')
ax.legend()

[Figure: plot of the 'no vmap' and 'vmap' outputs over x, showing a visible difference between the two curves]

Is this behavior to be expected, or is it a bug? I saw the related question #8712, but the answer there indicates that both approaches are comparable, while in this case using vmap is clearly worse in terms of precision.

With float64 enabled, both outputs are identical, even when setting x = jnp.linspace(0, 1e-30, 100)[:, None].
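For reference, this is roughly how float64 was enabled (a minimal sketch, assuming the standard jax_enable_x64 flag; it has to run before any arrays are created):

import jax

# enable double precision globally; must happen before creating any arrays
jax.config.update("jax_enable_x64", True)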

What jax/jaxlib version are you using?

jax v0.4.8, jaxlib v0.4.7

Which accelerator(s) are you using?

GPU

Additional system info

Ubuntu 22.04

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A300... On | 00000000:01:00.0 On | Off |
| N/A 52C P5 19W / 119W | 9739MiB / 12288MiB | 39% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2219 G /usr/lib/xorg/Xorg 105MiB |
| 0 N/A N/A 132485 C ...tions/venv3.10/bin/python 9630MiB |
+-----------------------------------------------------------------------------+

Linusnie added the bug label Apr 2, 2023
jakevdp (Collaborator) commented Apr 2, 2023

Thanks for the report - this appears to be backend-dependent behavior: I cannot reproduce the issue on CPU or on a T4 GPU. Perhaps it's somehow specific to the A300?

shoyer (Collaborator) commented Apr 4, 2023

vmap converts your matrix-vector products with jnp.dot into matrix-matrix products. By default, matrix-matrix multiplication in JAX uses TensorFloat-32 (TF32) precision on NVIDIA Ampere GPUs.
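One way to see this (an illustrative sketch, not from the original comment; shapes chosen to match the example above) is to compare the jaxprs of the per-example and vmapped calls:

import jax
import jax.numpy as jnp

w = jnp.zeros((512, 512))
xs = jnp.zeros((100, 512))  # a batch of per-example inputs

# per-example: a (512, 512) @ (512,) matrix-vector product
print(jax.make_jaxpr(lambda x: jnp.dot(w, x))(xs[0]))

# under vmap the batch axis is carried along as an extra output dimension, so
# the same line lowers to a single matrix-matrix dot_general over the whole
# (100, 512) batch, which is what hits the TF32 matmul path on Ampere GPUs
print(jax.make_jaxpr(jax.vmap(lambda x: jnp.dot(w, x)))(xs))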

To fix this, try explicitly setting higher precision, i.e., jnp.dot(w, activations, precision='float32').

Note that you may have to upgrade to a very recent release of JAX to get this working, since this was fixed only recently: #14022
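For concreteness, a sketch of the fix applied to the predict function from the report (only the precision arguments change; this is an illustration of the suggestion above, not code from the thread):

def predict(params, image):
  # per-example predictions, with matmuls pinned to full float32 precision
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations, precision='float32') + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations, precision='float32') + final_b
  return logits

If per-call arguments are too invasive, JAX also exposes a jax.default_matmul_precision context manager (and a jax_default_matmul_precision config option) that raises the default precision for all matmuls.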

Linusnie (Author) commented Apr 4, 2023

@shoyer I can confirm that setting precision='float32' in jnp.dot fixes the issue, thanks for the explanation!

shoyer removed the bug label Apr 5, 2023
shoyer closed this as completed Apr 5, 2023