
Stacking flax.nnx.Module layers with vmap and forwarding with scan result in loss of precision in XLA backend #27228

Open
HeavyCrab opened this issue Mar 18, 2025 · 10 comments
Labels
bug Something isn't working

Comments

@HeavyCrab

Description

I originally reported this issue in the Flax repo, but the maintainers suggested it might be a JAX issue.
Below is the original bug report with relevant details.

I followed the instructions in the tutorial#scan-over-layers to build a network with multiple layers using nnx.vmap and to run the forward pass with nnx.scan. However, doing so results in a loss of precision in the XLA backend.

This is a minimal example to reproduce the error.

import os
import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16

class AttentionLayer(nnx.Module):
    def __init__(self, d_model, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=8, 
            in_features=d_model,
            rngs=rngs
        )
        self.linear1 = nnx.Linear(in_features=d_model, out_features=d_model, rngs=rngs)
        
    def __call__(self, x):
        x = self.attention(x, decode=False)
        x = self.linear1(x)
        return x

def foo(x, layer_keys):
    @nnx.vmap(in_axes=0, out_axes=0)
    def create_layer(key):
        layer_rngs = nnx.Rngs(key)
        return AttentionLayer(FEATURES, layer_rngs)

    model = create_layer(layer_keys)

    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
    def apply_layers(layer, x):
        return layer(x)

    return apply_layers(model, x)

def bar(x, layer_keys):
    layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
    for layer in layers:
        x = layer(x)
    return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)   # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))


foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works

Executing function foo results in the following error.
Error on the 3090 machine:

E0316 21:11:08.907722 2675508 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16491
E0316 21:11:08.907797 2675508 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91315
E0316 21:11:08.907806 2675508 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.65029
E0316 21:11:08.907811 2675508 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51971
E0316 21:11:08.907816 2675508 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7088
E0316 21:11:08.907827 2675508 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25697
E0316 21:11:08.907832 2675508 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.9519
E0316 21:11:08.907838 2675508 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.93915
E0316 21:11:08.907845 2675508 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7196
E0316 21:11:08.907852 2675508 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75514
2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.

Error on the A100 machine:

E0317 05:01:29.217242  712317 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16615
E0317 05:01:29.217324  712317 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91452
E0317 05:01:29.217329  712317 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.651
E0317 05:01:29.217332  712317 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51998
E0317 05:01:29.217335  712317 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7096
E0317 05:01:29.217339  712317 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25718
E0317 05:01:29.217342  712317 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.95259
E0317 05:01:29.217345  712317 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.9386
E0317 05:01:29.217348  712317 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7191
E0317 05:01:29.217351  712317 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75621
2025-03-17 05:01:29.217365: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.

The function bar should be equivalent to foo, yet it runs without any errors.
Disabling XLA autotuning with os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" prevents the error, so the problem may be related to the autotune mechanism interacting with vmap and scan.
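
A minimal sketch of that workaround (assuming XLA_FLAGS has to be set before JAX initializes its GPU backend, so I place it before importing jax):

import os

# Disable GPU autotuning. Assumption: XLA reads XLA_FLAGS when the backend is
# first initialized, so set it before importing jax to be safe.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax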

Colab reproducibility

I cannot reproduce the error on Colab. When running this code on Colab with a CPU, the results of the two functions differ slightly, while on Colab with a T4 GPU the results are identical.
code:

x1 = foo(x, layer_keys)  # errors
x2 = bar(x, layer_keys)  # works
print(x1-x2)

output on CPU:

[[[-1.1920929e-07  1.1920929e-07  1.1920929e-07  2.9802322e-08
    1.1920929e-07  1.1920929e-07  0.0000000e+00  0.0000000e+00
   -1.1920929e-07  2.9802322e-07 -1.4901161e-08 -1.3411045e-07
   -5.9604645e-08 -1.1920929e-07 -1.1920929e-07  1.1920929e-07]
  [ 2.9802322e-08 -5.9604645e-08 -1.1920929e-07  2.0861626e-07
   -2.3841858e-07 -1.1920929e-07 -1.7881393e-07  0.0000000e+00
   -1.1920929e-07  1.7881393e-07  5.9604645e-08  2.9802322e-08
   -4.4703484e-08  0.0000000e+00 -1.1920929e-07  4.7683716e-07]]

 [[ 1.7695129e-07 -6.7055225e-08  8.6612999e-08 -1.2665987e-07
    1.1920929e-07  5.9604645e-08 -7.4505806e-09  1.1548400e-07
    1.4901161e-07  0.0000000e+00  1.1920929e-07  7.4505806e-08
    2.9802322e-08  5.9604645e-08 -1.1920929e-07 -1.7881393e-07]
  [ 2.2351742e-08  2.9802322e-08  9.6857548e-08 -1.0058284e-07
    0.0000000e+00  0.0000000e+00 -1.4901161e-08  1.4901161e-08
    1.4901161e-07  3.7252903e-08  8.9406967e-08 -5.9604645e-08
    0.0000000e+00  1.7881393e-07 -1.7881393e-07 -1.7881393e-07]]]

output on T4:

[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]

System info (python version, jaxlib version, accelerator, etc.)

  • OS Platform: Ubuntu 20.04
  • Flax, jax, jaxlib versions:
flax                      0.10.4                   pypi_0    pypi
jax                       0.5.2                    pypi_0    pypi
jax-cuda12-pjrt           0.5.1                    pypi_0    pypi
jax-cuda12-plugin         0.5.1                    pypi_0    pypi
jaxlib                    0.5.1                    pypi_0    pypi
optax                     0.2.4                    pypi_0    pypi
orbax-checkpoint          0.11.8                   pypi_0    pypi
  • Python version: 3.11.11
  • GPU: RTX 3090, 25.31G and NVIDIA A100 80GB PCIe
@mattjj
Collaborator

mattjj commented Mar 18, 2025

I believe this is a JAX (or rather XLA) issue, but can we make a jax-only reproducer? That’s the best way to tell when to file against jax instead of flax/nnx. Moreover it speeds things up because it’s the first thing we would need to do anyway, but I don’t know flax so I’d be slow at it!

@mattjj
Collaborator

mattjj commented Mar 18, 2025

@cgarciae would it be easy for you to make a jax-only reproducer here?

@cgarciae
Collaborator

cgarciae commented Mar 18, 2025

@mattjj here's a jax-only version:

from functools import partial
import os
import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16


class AttentionLayer(nnx.Module):
  def __init__(self, d_model, rngs):
    self.attention = nnx.MultiHeadAttention(
      num_heads=8, in_features=d_model, rngs=rngs
    )
    self.linear1 = nnx.Linear(
      in_features=d_model, out_features=d_model, rngs=rngs
    )

  def __call__(self, x):
    x = self.attention(x, decode=False)
    x = self.linear1(x)
    return x


def foo(x, layer_keys):
  @partial(jax.vmap, in_axes=0, out_axes=0)
  def create_layer(key):
    layer_rngs = nnx.Rngs(key)
    return nnx.split(AttentionLayer(FEATURES, layer_rngs))

  jax_model = create_layer(layer_keys)

  def apply_layers(x, jax_layer):
    layer = nnx.merge(*jax_layer)
    y = layer(x)
    return y, nnx.split(layer)

  x, jax_layer = jax.lax.scan(apply_layers, x, jax_model)
  return x


def bar(x, layer_keys):
  layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
  for layer in layers:
    x = layer(x)
  return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)  # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))


foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works

@cgarciae
Collaborator

I ran the original NNX code on an A100 on Colab and I couldn't reproduce the error.

@HeavyCrab
Author

I ran the original NNX code on an A100 on Colab and I couldn't reproduce the error.

This jax-only code causes the same error both on my 3090 and A100 machines.

@mattjj
Collaborator

mattjj commented Mar 19, 2025

@HeavyCrab on your 3090 and A100 machines (i.e. not on CPU or GPU T4) does the computation fail (i.e. raise an exception) with the error message "2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.", or is it instead that the program runs but the values in print(x1 - x2) seem too large? If the latter, what is the output of print(x1 - x2)?

@HeavyCrab
Author

@HeavyCrab on your 3090 and A100 machines (i.e. not on CPU or GPU T4) does the computation fail (i.e. raise an exception) with the error message "2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.", or is it instead that the program runs but the values in print(x1 - x2) seem too large? If the latter, what is the output of print(x1 - x2)?

Hi @mattjj, it does not throw an exception (the process was still alive). The full output of this code

x1 = foo(x, layer_keys)  # errors
x2 = bar(x, layer_keys)  # works
print(x1-x2)

is:
The 3090 machine:

$ python bug.py 
E0319 01:07:20.702570 3435184 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16491
E0319 01:07:20.702631 3435184 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91315
E0319 01:07:20.702635 3435184 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.65029
E0319 01:07:20.702637 3435184 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51971
E0319 01:07:20.702639 3435184 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7088
E0319 01:07:20.702642 3435184 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25697
E0319 01:07:20.702645 3435184 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.9519
E0319 01:07:20.702647 3435184 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.93915
E0319 01:07:20.702649 3435184 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7196
E0319 01:07:20.702651 3435184 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75514
2025-03-19 01:07:20.702659: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[[[ 7.8338385e-04 -5.6260824e-04  8.8912249e-04 -5.3942204e-06
    2.1111965e-03  1.9997358e-03 -7.1591139e-04 -1.5911460e-04
    1.6542077e-03 -1.2489557e-03  2.6896596e-05  1.2588501e-04
    3.5794079e-04  2.0803213e-03 -1.4548302e-03 -3.0951500e-03]
  [ 8.2293153e-04 -5.9169531e-04  8.3804131e-04  1.8194318e-04
    2.0978451e-03  1.8671751e-03 -9.3191862e-04 -7.3680282e-04
    1.6624928e-03 -1.1649132e-03 -5.1063299e-04 -5.8770180e-05
    6.2733889e-06  1.8830299e-03 -1.5131235e-03 -2.6221275e-03]]

 [[-2.4712086e-04 -1.1175126e-04  9.8645687e-05  5.5686384e-04
    1.4672279e-03  1.3265610e-03 -4.0589273e-04 -1.2467057e-04
    3.4034252e-04  3.6756694e-04 -1.7744303e-04  2.0503998e-05
   -6.9835782e-04  4.8100948e-04 -1.5387535e-03 -9.7024441e-04]
  [ 6.9305301e-05 -2.0000339e-04  2.2349879e-04  6.4449012e-04
    1.0813475e-03  7.3879957e-04 -1.8224120e-04  2.5500357e-04
    5.8168173e-04  2.2760034e-04 -2.7112663e-04  6.6542625e-04
   -7.7524781e-04  8.3667040e-04 -1.0095239e-03 -8.0049038e-04]]]

The A100 machine:

$ python pre.py 
E0319 09:10:15.880207 1577341 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16615
E0319 09:10:15.880277 1577341 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91452
E0319 09:10:15.880282 1577341 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.651
E0319 09:10:15.880285 1577341 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51998
E0319 09:10:15.880288 1577341 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7096
E0319 09:10:15.880292 1577341 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25718
E0319 09:10:15.880295 1577341 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.95259
E0319 09:10:15.880305 1577341 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.9386
E0319 09:10:15.880309 1577341 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7191
E0319 09:10:15.880314 1577341 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75621
2025-03-19 09:10:15.880328: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[[[ 7.6502562e-04 -7.1328878e-04  6.9761276e-04  4.2319298e-06
    1.9297600e-03  1.6930103e-03 -7.3444843e-04 -7.7733397e-04
    1.7562509e-03 -1.4949441e-03 -7.7141821e-04 -6.8747997e-04
    5.5819750e-05  1.0923147e-03 -1.0086298e-03 -2.5743246e-03]
  [ 5.2657723e-04 -2.9861927e-05 -2.0587444e-04 -2.8568506e-04
    2.2766590e-03  1.1899471e-03  1.6754866e-04 -1.0737777e-04
    1.1651516e-03 -8.2564354e-04 -6.3432753e-04 -8.7797642e-05
    1.6704202e-05  6.9272518e-04 -9.7465515e-04 -1.8434525e-03]]

 [[-2.0973384e-06 -3.8111955e-04  1.8106028e-04  1.3791174e-03
    2.1243095e-03  1.4629960e-03 -5.3527206e-04 -6.9333613e-04
    4.8518181e-04  1.5869737e-05 -5.2502751e-04 -2.9361248e-04
   -1.0331273e-03 -3.0705333e-04 -1.4066100e-03 -7.3063374e-04]
  [ 1.3628602e-04 -2.9978156e-04 -6.0886145e-05  1.1860803e-03
    1.5958548e-03  1.1484623e-03 -3.5095215e-04  6.5729022e-05
    4.9197674e-04  2.8118491e-05 -4.9480796e-04  5.1674247e-04
   -1.2629628e-03  4.8178434e-04 -1.2794733e-03 -1.0079741e-03]]]

Additionally, according to this issue I submitted today, the error disappears when I run the code like this:

with jax.default_matmul_precision('float32'):
    x1 = foo(x, layer_keys)  # errors
    x2 = bar(x, layer_keys)  # works
    print(x1-x2)

The output:

$ python bug.py 
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]

I hope this information helps.

@mattjj
Collaborator

mattjj commented Mar 19, 2025

So can we close this issue then? It seems the results agree to within the expected precision, given the matmuls weren't performed in f32.

@HeavyCrab
Author

So can we close this issue then? It seems the results agree to within the expected precision, given the matmuls weren't performed in f32.

@mattjj Sorry, I have a few questions to clarify my understanding. Should we specify float32 precision for all matmul operations in this context? Or should we manually specify float32 precision only when an XLA error occurs for a specific matmul operation? Alternatively, do you think it might be acceptable to ignore the XLA error in this case?
I noticed that the matmul operation was unable to automatically select the correct precision, causing an XLA error, which seems to be an issue in itself. Additionally, the fact that two mathematically equivalent functions can result in different precision issues (one causing an error and the other not) is also somewhat concerning.
I'm not entirely sure what optimizations vmap and scan perform internally that lead to precision issues with XLA. However, it seems to me that automatic precision selection should ideally be managed by JAX/XLA rather than the user. What are your thoughts on this?

@mattjj
Collaborator

mattjj commented Mar 19, 2025

Additionally, the fact that two mathematically equivalent functions can result in different precision issues (one causing an error and the other not) is also somewhat concerning.

They're equivalent in exact arithmetic, but not for floating point numbers. For floats, we would expect some deviation depending on the precision in which the computations are performed. Here the matrix multiplies were carried out at a lower precision, namely the A100 default of tensorfloat32, and my understanding is that the deviations observed are within expectations.

Should we specify float32 precision for all matmul operations in this context?

It just depends on if you want higher precision. There's a platform-dependent default, and then many other options to choose from.
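
As a rough sketch of a few of those options (names as I recall them from the JAX docs; pick whichever scope fits your use case):

import jax
import jax.numpy as jnp

# Raise the default matmul precision for the whole program.
jax.config.update("jax_default_matmul_precision", "float32")

# Or raise it only within a scope, as in the snippet above.
with jax.default_matmul_precision("float32"):
    y = jnp.dot(jnp.ones((4, 4)), jnp.ones((4, 4)))

# Or request higher precision for a single operation.
z = jnp.dot(jnp.ones((4, 4)), jnp.ones((4, 4)), precision=jax.lax.Precision.HIGHEST)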

Or should we manually specify float32 precision only when an XLA error occurs for a specific matmul operation? Alternatively, do you think it might be acceptable to ignore the XLA error in this case?

I don't know how to interpret the XLA error message. I don't see any reason to think it's related to the computed results you observed, since those are within expected tolerances. It may even be XLA-internal logging. I would just ignore it until we have reason to believe otherwise. I opened openxla/xla#23934 to ask XLA folks.

I noticed that the matmul operation was unable to automatically select the correct precision, causing an XLA error, which seems to be an issue in itself. [...] However, it seems to me that automatic precision selection should ideally be managed by JAX/XLA rather than the user. What are your thoughts on this?

The precision at which to perform matmuls is a decision only you (the programmer) can make. It trades off speed and accuracy, and the right tradeoff depends entirely on what you're trying to do.
