Stacking flax.nnx.Module layers with vmap and forwarding with scan results in loss of precision in XLA backend #27228
Comments
I believe this is a JAX (or rather XLA) issue, but can we make a jax-only reproducer? That's the best way to tell when to file against jax instead of flax/nnx. Moreover, it speeds things up because it's the first thing we would need to do anyway, but I don't know flax so I'd be slow at it!
@cgarciae would it be easy for you to make a jax-only reproducer here?
@mattjj here's a jax-only version:

```python
from functools import partial
import os

import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16


class AttentionLayer(nnx.Module):
    def __init__(self, d_model, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=8, in_features=d_model, rngs=rngs
        )
        self.linear1 = nnx.Linear(
            in_features=d_model, out_features=d_model, rngs=rngs
        )

    def __call__(self, x):
        x = self.attention(x, decode=False)
        x = self.linear1(x)
        return x


def foo(x, layer_keys):
    # Stack the parameters of several layers along a leading axis with vmap.
    @partial(jax.vmap, in_axes=0, out_axes=0)
    def create_layer(key):
        layer_rngs = nnx.Rngs(key)
        return nnx.split(AttentionLayer(FEATURES, layer_rngs))

    jax_model = create_layer(layer_keys)

    # Apply the stacked layers sequentially with scan.
    def apply_layers(x, jax_layer):
        layer = nnx.merge(*jax_layer)
        y = layer(x)
        return y, nnx.split(layer)

    x, jax_layer = jax.lax.scan(apply_layers, x, jax_model)
    return x


def bar(x, layer_keys):
    # Reference implementation: a plain Python loop over separate layers.
    layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
    for layer in layers:
        x = layer(x)
    return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)  # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))

foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works
```
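To make the comparison concrete, here is a small sketch for quantifying how far the two paths diverge (the 1e-3 tolerance is my own rough choice for tensorfloat32-precision matmuls, not a value from this thread):

```python
import jax.numpy as jnp

out_foo = foo(x, layer_keys)
out_bar = bar(x, layer_keys)

# Largest elementwise deviation between the scan version and the Python loop.
print("max |foo - bar| =", jnp.max(jnp.abs(out_foo - out_bar)))
# Agreement at a tolerance loose enough for tensorfloat32 matmuls.
print("allclose at 1e-3:", jnp.allclose(out_foo, out_bar, atol=1e-3, rtol=1e-3))
```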
I ran the original NNX code on an A100 on Colab and I couldn't reproduce the error.
This jax-only code causes the same error both on my 3090 and A100 machines.
@HeavyCrab on your 3090 and A100 machines (i.e. not on CPU or GPU T4), does the computation fail (i.e. raise an exception) with the error message "2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.", or is it instead that the program runs but the values in the results of foo and bar differ?
Hi @mattjj, it does not throw an exception (the process was still alive). The full output of this code is:
The A100 machine:
Additionally, according to this issue I submitted today, the error disappears after I did this:
The output:
I hope this information helps.
So can we close this issue then? It seems the results were agreeing to within expected precision, given the matmuls weren't in f32.
@mattjj Sorry, I have a few questions to clarify my understanding. Should we specify float32 precision for all matmul operations in this context? Or should we manually specify float32 precision only when an XLA error occurs for a specific matmul operation? Alternatively, do you think it might be acceptable to ignore the XLA error in this case?
They're equivalent in exact arithmetic, but not for floating point numbers. For floats, we would expect some deviation depending on the precision in which the computations are performed. Here the matrix multiplies were carried out at a lower precision, namely the A100 default of tensorfloat32, and my understanding is that the deviations observed are within expectations.
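A back-of-the-envelope check (my own numbers, based on the TF32 format having 10 explicit mantissa bits versus 23 for float32):

```python
# Each TF32 product is rounded with relative error up to about 2**-11,
# versus about 2**-24 for full float32 arithmetic.
unit_roundoff_tf32 = 2.0 ** -11  # ~4.9e-4
unit_roundoff_f32 = 2.0 ** -24   # ~6.0e-8
print(unit_roundoff_tf32, unit_roundoff_f32)
```

So per-element deviations on the order of 1e-4 to 1e-3 between a TF32 path and a float32 path are unsurprising.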
It just depends on if you want higher precision. There's a platform-dependent default, and then many other options to choose from.
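For example (a minimal sketch of the standard JAX precision knobs, not something prescribed in this thread), the default can be raised globally or for a limited scope:

```python
import jax

# Raise the default matmul precision for the whole process.
jax.config.update("jax_default_matmul_precision", "float32")

# Or raise it only within a scope.
with jax.default_matmul_precision("float32"):
    out = foo(x, layer_keys)
```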
I don't know how to interpret the XLA error message. I don't see any reason to think it's related to the computed results you observed, since those are within expected tolerances. It may even be XLA-internal logging. I would just ignore it until we have reason to believe otherwise. I opened openxla/xla#23934 to ask XLA folks.
The precision at which to perform matmuls is a decision only you (the programmer) can make. It trades off speed and accuracy, and the right tradeoff depends entirely on what you're trying to do.
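As an illustration (a sketch with stand-in operands a and b), the tradeoff can also be made per call via the precision argument:

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128))
b = jnp.ones((128, 128))

fast = jnp.dot(a, b)  # platform default (tensorfloat32 on A100)
accurate = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)  # full float32
```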
Description
I originally reported this issue in the Flax repo, but the maintainers suggested it might be a JAX issue.
Below is the original bug report with relevant details.
I followed the instructions in the tutorial (scan-over-layers) to build a network with multiple layers with `nnx.vmap` and to forward with `nnx.scan`. However, doing so results in a loss of precision in the XLA backend. This is a minimal example to reproduce the error.
Executing function `foo` results in the following error.
Error on the 3090 machine:
Error on the A100 machine:
The function `bar` should be equivalent to `foo`, but works well without any errors. Disabling XLA autotune via `os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"` prevents the error, so it might be something related to the autotune mechanism with `vmap` and `scan`.
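A workaround sketch based on the flag above (XLA flags are read when JAX initializes its backend, so they must be set before the first JAX call):

```python
import os

# Must be set before JAX initializes the GPU backend.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax  # backend picks up the flag on first use
```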
Colab reproducibility
I cannot reproduce it on Colab. When running this code on Colab with a CPU, the results of the two functions are slightly different, while on a T4 GPU the computed results are identical.
code:
output on CPU:
output on T4:
System info (python version, jaxlib version, accelerator, etc.)