
Stacking layers with vmap and forwarding with scan result in loss of precision in XLA backend #4629

Open · HeavyCrab opened this issue Mar 16, 2025 · 2 comments

@HeavyCrab

I followed the instructions in tutorial#scan-over-layers to build a multi-layer network with nnx.vmap and to forward through it with nnx.scan. However, doing so results in a loss-of-precision error in the XLA backend.

System information

  • OS Platform: Ubuntu 20.04
  • Flax, jax, jaxlib versions:
flax                      0.10.4                   pypi_0    pypi
jax                       0.5.2                    pypi_0    pypi
jax-cuda12-pjrt           0.5.1                    pypi_0    pypi
jax-cuda12-plugin         0.5.1                    pypi_0    pypi
jaxlib                    0.5.1                    pypi_0    pypi
optax                     0.2.4                    pypi_0    pypi
orbax-checkpoint          0.11.8                   pypi_0    pypi
  • Python version: 3.11.11
  • GPU: NVIDIA RTX 3090 (25.31 GB) and NVIDIA A100 80GB PCIe

Problem you have encountered:

This is a minimal example to reproduce the error.

import os
import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16

class AttentionLayer(nnx.Module):
    def __init__(self, d_model, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=8, 
            in_features=d_model,
            rngs=rngs
        )
        self.linear1 = nnx.Linear(in_features=d_model, out_features=d_model, rngs=rngs)
        
    def __call__(self, x):
        x = self.attention(x, decode=False)
        x = self.linear1(x)
        return x

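# foo: stack the layers along a leading axis with nnx.vmap, then apply them
# sequentially with nnx.scan.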
def foo(x, layer_keys):
    @nnx.vmap(in_axes=0, out_axes=0)
    def create_layer(key):
        layer_rngs = nnx.Rngs(key)
        return AttentionLayer(FEATURES, layer_rngs)

    model = create_layer(layer_keys)

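    # Scan over the stacked layer parameters (axis 0); x is the carry
    # threaded through the layers.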
    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
    def apply_layers(layer, x):
        return layer(x)

    return apply_layers(model, x)

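# bar: reference implementation using a plain Python loop over separately
# created layers; it should compute the same result as foo.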
def bar(x, layer_keys):
    layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
    for layer in layers:
        x = layer(x)
    return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)   # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))


foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works

Executing the function foo results in the following error.
Error on the 3090 machine:

E0316 21:11:08.907722 2675508 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16491
E0316 21:11:08.907797 2675508 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91315
E0316 21:11:08.907806 2675508 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.65029
E0316 21:11:08.907811 2675508 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51971
E0316 21:11:08.907816 2675508 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7088
E0316 21:11:08.907827 2675508 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25697
E0316 21:11:08.907832 2675508 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.9519
E0316 21:11:08.907838 2675508 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.93915
E0316 21:11:08.907845 2675508 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7196
E0316 21:11:08.907852 2675508 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75514
2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.

Error on the A100 machine:

E0317 05:01:29.217242  712317 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16615
E0317 05:01:29.217324  712317 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91452
E0317 05:01:29.217329  712317 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.651
E0317 05:01:29.217332  712317 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51998
E0317 05:01:29.217335  712317 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7096
E0317 05:01:29.217339  712317 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25718
E0317 05:01:29.217342  712317 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.95259
E0317 05:01:29.217345  712317 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.9386
E0317 05:01:29.217348  712317 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7191
E0317 05:01:29.217351  712317 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75621
2025-03-17 05:01:29.217365: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.

The function bar should be equivalent to foo, yet it runs without any errors.
Disabling XLA autotuning via os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" prevents the error, so the problem may be related to the autotuning mechanism interacting with vmap and scan.
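For anyone hitting the same error, a minimal sketch of the workaround (assuming, as with any XLA_FLAGS setting, that the flag must be in place before JAX initializes the GPU backend):

import os

# XLA_FLAGS is read when the backend is initialized, so set it
# before the first jax import / device use.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax

A possibly related knob (an assumption on my part, not verified against this issue) is forcing full float32 matmul precision, since Ampere-class GPUs like the 3090 and A100 may use TF32 for matmuls by default:

# Assumption: forcing "highest" matmul precision disables TF32 GEMMs.
jax.config.update("jax_default_matmul_precision", "highest")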

Colab reproducibility

I cannot reproduce the error on Colab. When running this code on Colab with a CPU runtime, the results of the two functions differ slightly, while on a T4 GPU runtime the results are identical.
code:

x1 = foo(x, layer_keys)  # errors
x2 = bar(x, layer_keys)  # works
print(x1-x2)

output on CPU:

[[[-1.1920929e-07  1.1920929e-07  1.1920929e-07  2.9802322e-08
    1.1920929e-07  1.1920929e-07  0.0000000e+00  0.0000000e+00
   -1.1920929e-07  2.9802322e-07 -1.4901161e-08 -1.3411045e-07
   -5.9604645e-08 -1.1920929e-07 -1.1920929e-07  1.1920929e-07]
  [ 2.9802322e-08 -5.9604645e-08 -1.1920929e-07  2.0861626e-07
   -2.3841858e-07 -1.1920929e-07 -1.7881393e-07  0.0000000e+00
   -1.1920929e-07  1.7881393e-07  5.9604645e-08  2.9802322e-08
   -4.4703484e-08  0.0000000e+00 -1.1920929e-07  4.7683716e-07]]

 [[ 1.7695129e-07 -6.7055225e-08  8.6612999e-08 -1.2665987e-07
    1.1920929e-07  5.9604645e-08 -7.4505806e-09  1.1548400e-07
    1.4901161e-07  0.0000000e+00  1.1920929e-07  7.4505806e-08
    2.9802322e-08  5.9604645e-08 -1.1920929e-07 -1.7881393e-07]
  [ 2.2351742e-08  2.9802322e-08  9.6857548e-08 -1.0058284e-07
    0.0000000e+00  0.0000000e+00 -1.4901161e-08  1.4901161e-08
    1.4901161e-07  3.7252903e-08  8.9406967e-08 -5.9604645e-08
    0.0000000e+00  1.7881393e-07 -1.7881393e-07 -1.7881393e-07]]]

output on T4:

[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]
@cgarciae (Collaborator)

Hi @HeavyCrab, this sounds like a JAX issue, can you please report this to the JAX repo? Doesn't seem there's a ton we can do from the Flax side here. Sorry for the inconvenience.

@HeavyCrab (Author)

@cgarciae OK, I have reported this to the JAX repo.
