Skip to content

Commit

Permalink
[Pallas:MGPU] Fix an overly strict precision requirement in tests
Browse files Browse the repository at this point in the history
They started failing after we allowed LLVM to perform contractions of
adds and muls, but the difference is tiny.

PiperOrigin-RevId: 701961845
  • Loading branch information
apaszke authored and Google-ML-Automation committed Dec 2, 2024
1 parent 5d5b06c commit aff7714
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,7 @@ def layer_norm_np(x):
jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
* input_factor
)
# TODO(cperivol): find out why in this particular case we have a small-ish error.
rtol = 1e-07 if input_factor > 10 else 5e-5
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5)

def test_print(self):
@functools.partial(
Expand Down

0 comments on commit aff7714

Please sign in to comment.