I followed the instructions in the tutorial#scan-over-layers to build a network with multiple layers using nnx.vmap and to run the forward pass with nnx.scan. However, doing so results in a loss of precision in the XLA backend. The following is a minimal example to reproduce the error.
import os
import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16


class AttentionLayer(nnx.Module):
    def __init__(self, d_model, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=8,
            in_features=d_model,
            rngs=rngs,
        )
        self.linear1 = nnx.Linear(in_features=d_model, out_features=d_model, rngs=rngs)

    def __call__(self, x):
        x = self.attention(x, decode=False)
        x = self.linear1(x)
        return x


def foo(x, layer_keys):
    @nnx.vmap(in_axes=0, out_axes=0)
    def create_layer(key):
        layer_rngs = nnx.Rngs(key)
        return AttentionLayer(FEATURES, layer_rngs)

    model = create_layer(layer_keys)

    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
    def apply_layers(layer, x):
        return layer(x)

    return apply_layers(model, x)


def bar(x, layer_keys):
    layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
    for layer in layers:
        x = layer(x)
    return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)  # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))

foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works
Executing function foo results in the following error.
Error on the 3090 machine:
E0316 21:11:08.907722 2675508 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16491
E0316 21:11:08.907797 2675508 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91315
E0316 21:11:08.907806 2675508 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.65029
E0316 21:11:08.907811 2675508 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51971
E0316 21:11:08.907816 2675508 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7088
E0316 21:11:08.907827 2675508 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25697
E0316 21:11:08.907832 2675508 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.9519
E0316 21:11:08.907838 2675508 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.93915
E0316 21:11:08.907845 2675508 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7196
E0316 21:11:08.907852 2675508 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75514
2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Error on the A100 machine:
E0317 05:01:29.217242 712317 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16615
E0317 05:01:29.217324 712317 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91452
E0317 05:01:29.217329 712317 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.651
E0317 05:01:29.217332 712317 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51998
E0317 05:01:29.217335 712317 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7096
E0317 05:01:29.217339 712317 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25718
E0317 05:01:29.217342 712317 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.95259
E0317 05:01:29.217345 712317 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.9386
E0317 05:01:29.217348 712317 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7191
E0317 05:01:29.217351 712317 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75621
2025-03-17 05:01:29.217365: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
The function bar should be equivalent to foo, yet it runs without any errors.
Disabling XLA autotuning with os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0" prevents the error, so the problem might be related to the interaction between the autotuning mechanism and nnx.vmap/nnx.scan.
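For reference, a minimal sketch of that workaround (the flag value is the one quoted above; setting it before the XLA GPU backend is initialized, e.g. before importing jax, ensures XLA picks it up):

import os
# Disable XLA GPU autotuning; set this before the GPU backend is initialized
# (safest: before importing jax).
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax
from flax import nnx
# ... define AttentionLayer, foo, bar as above ...
# With autotuning disabled, foo(x, layer_keys) no longer triggers the
# buffer_comparator errors reported above.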
Colab reproducibility
I cannot reproduce the error on Colab. When running this code on Colab with a CPU, the results of the two functions are slightly different, while on Colab with a T4 GPU the computed results are identical.
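For what it's worth, here is an illustrative sketch of how the two outputs can be compared numerically (not the exact notebook code; the variable names and tolerance are arbitrary choices):

import jax.numpy as jnp

out_foo = foo(x, layer_keys)
out_bar = bar(x, layer_keys)

# Small numerical differences can arise from different op ordering/fusion,
# so compare with a tolerance rather than exact equality.
print("max abs diff:", jnp.max(jnp.abs(out_foo - out_bar)))
print("allclose:", jnp.allclose(out_foo, out_bar, atol=1e-5))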
Hi @HeavyCrab, this sounds like a JAX issue, can you please report this to the JAX repo? Doesn't seem there's a ton we can do from the Flax side here. Sorry for the inconvenience.