Wrong results on CPU since 0.4.32 #23590
Another CPU change in v0.4.32 is that some of the LAPACK wrappers in jaxlib changed. Does the kernel use any linear algebra?
Another thing to try: does setting `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false` change anything? There was a major change to the implementation of the CPU backend. Notably, we'll use more concurrency on CPU. If that fixes things, please share a reproduction.
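For example (a sketch; the flag must be in the environment before JAX initializes for it to take effect):

```python
import os

# XLA_FLAGS must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax

print(jax.devices())  # sanity check that JAX still initializes on CPU
```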
We actually just yanked 0.4.32 because of a TPU problem, but if you can get a reproducer it'd be great to look into this.
Thanks both for the suggestions. The kernel in question doesn't use linear algebra, and the suggested XLA flags didn't make a difference. I managed to dig up a reproducer. You can run this:

```sh
$ git clone git@github.com:dionhaefner/pyhpc-benchmarks.git
$ cd pyhpc-benchmarks
$ python run.py benchmarks/isoneutral_mixing/ --device cpu -b jax -b numpy -s 4096
```

With JAX 0.4.32 this reports a mismatch between the JAX and NumPy results, but not for older JAX versions. This is the kernel that's being run, with random input arrays:
Even simpler, you can run this script in the `benchmarks/isoneutral_mixing` folder:

```python
from isoneutral_numpy import run as run_numpy
from isoneutral_jax import run as run_jax

import numpy as np
import jax

jax.config.update("jax_enable_x64", True)


def generate_inputs(size):
    import math

    np.random.seed(17)

    shape = (
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(2 * size ** (1 / 3)),
        math.ceil(0.25 * size ** (1 / 3)),
    )

    # masks
    maskT, maskU, maskV, maskW = (
        (np.random.rand(*shape) < 0.8).astype("float64") for _ in range(4)
    )

    # 1d arrays
    dxt, dxu = (np.random.randn(shape[0]) for _ in range(2))
    dyt, dyu = (np.random.randn(shape[1]) for _ in range(2))
    dzt, dzw, zt = (np.random.randn(shape[2]) for _ in range(3))
    cost, cosu = (np.random.randn(shape[1]) for _ in range(2))

    # 3d arrays
    K_iso, K_11, K_22, K_33 = (np.random.randn(*shape) for _ in range(4))

    # 4d arrays
    salt, temp = (np.random.randn(*shape, 3) for _ in range(2))

    # 5d arrays
    Ai_ez, Ai_nz, Ai_bx, Ai_by = (np.zeros((*shape, 2, 2)) for _ in range(4))

    return (
        maskT, maskU, maskV, maskW,
        dxt, dxu, dyt, dyu, dzt, dzw,
        cost, cosu,
        salt, temp, zt,
        K_iso, K_11, K_22, K_33,
        Ai_ez, Ai_nz, Ai_bx, Ai_by,
    )


testinputs = generate_inputs(1000)


def test_run():
    inputs_np = [x.copy() for x in testinputs]
    inputs_jax = [jax.numpy.asarray(x) for x in testinputs]

    out_numpy = run_numpy(*inputs_np)
    out_jax = run_jax(*inputs_jax)

    for x_np, x_jax in zip(out_numpy, out_jax):
        # assert_allclose defaults to rtol=1e-7, atol=0
        np.testing.assert_allclose(x_np, x_jax)


if __name__ == "__main__":
    test_run()
```
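Assuming the script above is saved as `repro.py` inside `benchmarks/isoneutral_mixing/` (so the `isoneutral_*` imports resolve), it can be run directly or via pytest:

```sh
$ cd pyhpc-benchmarks/benchmarks/isoneutral_mixing
$ python repro.py      # runs test_run() via the __main__ guard
$ pytest repro.py -q   # pytest also discovers test_run
```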
Thank you for the code! I was able to reproduce the issue. Our team will look into this soon.
This is because of my f64 Tanh approximation commit: openxla/xla@ae96f6e. I've temporarily disabled the feature in openxla/xla@8fcf359. The change is now in JAX nightly 20240914 and newer.

I've verified that the script doesn't get numerical errors anymore with the new nightly wheel. The original benchmark also passed.

I'll also check that both the script and the benchmark run fine before re-enabling the fast f64 Tanh approximation.
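A minimal way to probe this on a given install (a rough sketch; the tolerance below is a guess at "a few ulps" rather than an exact bound) is to compare JAX's float64 tanh against NumPy directly:

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Compare JAX's float64 tanh against NumPy's reference implementation.
# A build with the fast f64 tanh approximation enabled should deviate
# far more than the tolerance below; fixed builds should pass.
x = np.random.default_rng(0).standard_normal(100_000)
got = np.asarray(jnp.tanh(x))
want = np.tanh(x)
print("max abs difference:", np.max(np.abs(got - want)))
np.testing.assert_allclose(got, want, rtol=1e-14, atol=1e-15)
```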
I'm going to cherry-pick this change into the jax v0.4.33 release that I'm about to make.
We just released JAX 0.4.33, which includes the fix for this issue. |
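To double-check on an affected machine (a small sketch; `repro.py` refers to the script above):

```sh
$ pip install -U "jax>=0.4.33"
$ python -c "import jax; print(jax.__version__)"  # expect 0.4.33 or newer
$ python repro.py  # the assert_allclose checks should now pass
```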
Description
I'm seeing test failures in Veros when bumping JAX to 0.4.32.
There appear to be significant deviations from the expected results (which are computed by a Fortran reference). All tests are executed on CPU and pass for previous versions of JAX.
I've tried setting

`jax.config.update('jax_cpu_enable_async_dispatch', False)`

(since that's the only thing I saw in the changelog that I thought might be related), but it made no difference. Is there anything else I could try on my end? I'm asking for wild guesses here, because going through the motions to isolate the problem is really nontrivial (these are complicated kernels consisting of thousands of SLOC).
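For reference, the full toggle looks like this (a minimal sketch; my understanding, not verified, is that it only changes dispatch order, not the compiled code):

```python
import jax

# Revert to synchronous dispatch on CPU. This needs to run before any
# computations are dispatched. (Assumption: like most jax.config options,
# it can also be set via the JAX_CPU_ENABLE_ASYNC_DISPATCH environment
# variable.)
jax.config.update("jax_cpu_enable_async_dispatch", False)
```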
System info (python version, jaxlib version, accelerator, etc.)
Python 3.12, JAX 0.4.32, CPU
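For a fuller report, JAX can print this itself; a quick sketch:

```python
import jax

# Prints Python, jax, jaxlib, and NumPy versions plus detected devices.
jax.print_environment_info()
```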