It is somewhat unintuitive that the default matmul precision is bfloat16 on TPU, especially for users coming from PyTorch/GPU, where the default precision is float32. Information regarding the default matrix-multiplication precision on TPUs is extremely difficult to find. There is a short section in the README.md of the cloud TPU Colab folder of the JAX repo: https://github.com/google/jax/tree/main/cloud_tpu_colabs#bfloat16-dtype However, this is somewhat unclear, as it references 'MXUs' without explaining the abbreviation (matrix multiply units), and only shows how the default precision can be overridden manually on an op-by-op basis by setting precision=jax.lax.Precision.XXX. This gives the impression that, in order to change the TPU precision to float32, one must insert the keyword argument precision=jax.lax.Precision.HIGHEST for every jax.numpy operation in one's script.
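For reference, a minimal sketch of what that op-by-op override looks like (the array values here are just placeholders; on non-TPU backends the keyword is accepted but float32 is already the default):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4))
y = jnp.ones((4, 4))

# Forces this one matmul to run at the highest available precision
# (full float32 on TPU); every other op keeps the backend default.
z = jnp.matmul(x, y, precision=jax.lax.Precision.HIGHEST)
```

Having to repeat `precision=...` like this on every call is exactly the obtrusiveness the op-by-op approach implies.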
It is also difficult to find how the default precision can be changed. Performing matmul operations in the default bfloat16 precision can lead to undesirable results. At Hugging Face, we repeatedly run into problems caused by the fast but low-precision TPU default, as shown here for example: Diverging PT-Flax Wav2Vec2 Hidden-States huggingface/transformers#15754
In the case of changing the default matmul precision, the docs do mention the default matmul precision context manager: https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html However, they do not explicitly show how to use this context manager (for instance with an example). It's hard to know from the docs that you have to place your code under the context manager as follows:
with jax.default_matmul_precision('float32'):  # or 'bfloat16' for lowest
    ... = foo(...)
The docs also gloss over three additional methods for changing the default matmul precision, highlighted brilliantly in this PR: #6143 (comment) These three methods require no change to one's actual script, just a shell/command-line flag or a JAX config change, and are arguably much easier to use and less obtrusive.
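As a sketch, the three script-free alternatives look roughly like the following; the option name `jax_default_matmul_precision` is taken from the linked PR, so the exact spellings should be verified against your JAX version:

```python
# Three global ways to set the default matmul precision without touching
# any individual jax.numpy call (names assumed from PR #6143):

# 1) Environment variable, set before the script starts:
#      JAX_DEFAULT_MATMUL_PRECISION=float32 python train.py

# 2) Command-line flag, if the script parses absl flags:
#      python train.py --jax_default_matmul_precision=float32

# 3) Programmatically, near the top of the script:
import jax
jax.config.update('jax_default_matmul_precision', 'float32')
```

Any of these changes the default globally, so downstream library code benefits without modification.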
It would be great if the default matmul precisions for CPU/GPU/TPU were documented, along with what bfloat16, tensorfloat32, and float32 precision actually mean for matmuls in terms of the number of passes. It would also be super helpful if all four methods for manipulating the default precision were added to the docs, with short examples of how to use them, as done in the aforementioned PR.
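To make concrete what the bfloat16 default costs in accuracy, here is a small pure-Python sketch that emulates bfloat16 by truncating a float32 to its top 16 bits: bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 explicit mantissa bits, i.e. roughly 2-3 decimal digits per element (and the multi-pass HIGH/HIGHEST modes exist precisely to recover the lost mantissa bits):

```python
import struct

def to_bfloat16(x: float) -> float:
    # Round x to float32, then keep only the top 16 bits of its bit
    # pattern: sign (1) + exponent (8) + mantissa (7). This emulates
    # bfloat16 with round-toward-zero.
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

print(to_bfloat16(1.0))      # 1.0  (exactly representable)
print(to_bfloat16(1.001))    # 1.0  (the 0.001 falls below 7 mantissa bits)
print(to_bfloat16(3.14159))  # 3.140625
```

With steps of ~1/128 near 1.0, small per-element errors like these accumulate across the reductions inside a large matmul, which is why hidden states can visibly diverge from float32 references.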