CPU backend failure on ARM due to XLA/LLVM (and a potential fix) #5679
Thanks @girgink! We will try to fix this internally in the TensorFlow tree.
I don't have an environment to test this, but I believe it should be enough to get the JIT to work. This came up in jax-ml/jax#5679 PiperOrigin-RevId: 356578451 Change-Id: I99a2aa0e87739b9edce81074fce1ca5c0bd25115
tensorflow/tensorflow@942f315 links in AArch64 codegen and sets the cmake vars. Let us know if that's sufficient to get JAX to work on your machine.
Thanks @d0k, I will build again and inform you accordingly.
@girgink Note we haven't bumped the TF version in the JAX WORKSPACE file. You'll either need to do that or point the WORKSPACE file to an up-to-date TF checkout that includes that change.
Ok!
@d0k @hawkinsp I confirm that the fix works as expected. I also ran the test suite. All errors are related to a call to a deprecated create function. As for the failed tests, I suspect they might be related to available memory. The unit that I use (NVIDIA Jetson AGX Xavier) has unified CPU-GPU memory. I'm not sure whether this is the case for JAX, but some frameworks - e.g. TensorFlow - allocate most of the GPU memory up front to speed up processing, and because the Xavier has unified memory this leaves very little CPU memory. Running the tests in parallel makes this worse. Please find the test results below:
The warning about the deprecated create function is less of a worry; the segmentation faults are more concerning. Are you running these tests on CPU or GPU? By default, they are probably GPU tests. You can prevent GPU usage by setting the appropriate environment variable.
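For example, either of the following keeps the tests on CPU (the specific variable was lost in the quoted comment, so treat these as standard options rather than a quote):

```sh
# Hide all CUDA devices from the process:
CUDA_VISIBLE_DEVICES="" pytest -n auto tests

# ...or ask JAX itself to use the CPU backend:
JAX_PLATFORM_NAME=cpu pytest -n auto tests
```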
I tried running the tests that failed with a segfault under ARM emulation and they don't fail at head for me. I'm going to guess that they are fixed at head. If you can still reproduce these segfaults at head, please reopen the bug!
Hiya, looks like this bug is back in the latest JAX. On a fresh MacBook Pro M1 (native, not Rosetta), I installed numpy and scipy, and then I'm hitting the crash described above.
Since this is the only GitHub issue that comes up related to JAX referencing this error, I'm commenting here.
Sadly I don't have an M1 Mac to verify, but it's very likely that you're hitting the BUILD file problem that tensorflow/tensorflow@cd76ed3 is supposed to fix. Does it still occur after that change? The proper fix for that is blocked by TF supporting older versions of Bazel, but the workaround should be enough for it to recognize your machine as arm64.
After many hours, I'm reasonably confident that neither tensorflow/tensorflow@cd76ed3 nor tensorflow/tensorflow@e24a3b5 fixes the problem. How can I hack the TensorFlow codebase to force arm64? Are you sure that workaround is sufficient? My build process was:
diff --git a/WORKSPACE b/WORKSPACE
index f4d50e5c..83b56c7c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,24 +1,24 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
-# To update TensorFlow to a new revision,
-# a) update URL and strip_prefix to the new git commit hash
-# b) get the sha256 hash of the commit by running:
-# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
-# and update the sha256 with the result.
-http_archive(
- name = "org_tensorflow",
- sha256 = "b2c8b912e7be71306ab6fee063fb4ec1dfbe7158e7e8469d674f8af6583434d4",
- strip_prefix = "tensorflow-e98b052c08e5d1e7906ac2f6caf95c51a1e04985",
- urls = [
- "https://github.com/tensorflow/tensorflow/archive/e98b052c08e5d1e7906ac2f6caf95c51a1e04985.tar.gz",
- ],
-)
+# # To update TensorFlow to a new revision,
+# # a) update URL and strip_prefix to the new git commit hash
+# # b) get the sha256 hash of the commit by running:
+# # curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+# # and update the sha256 with the result.
+# http_archive(
+# name = "org_tensorflow",
+# sha256 = "b2c8b912e7be71306ab6fee063fb4ec1dfbe7158e7e8469d674f8af6583434d4",
+# strip_prefix = "tensorflow-e98b052c08e5d1e7906ac2f6caf95c51a1e04985",
+# urls = [
+# "https://github.com/tensorflow/tensorflow/archive/e98b052c08e5d1e7906ac2f6caf95c51a1e04985.tar.gz",
+# ],
+# )
# For development, one can use a local TF repository instead.
-# local_repository(
-# name = "org_tensorflow",
-# path = "tensorflow",
-# )
+local_repository(
+ name = "org_tensorflow",
+ path = "tensorflow",
+)
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()

Lastly, build and install jaxlib.
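Presumably this was the standard jaxlib build flow, something like the following (a sketch, not necessarily the exact commands used):

```sh
# Build the jaxlib wheel against the local TensorFlow checkout referenced above:
python build/build.py

# Install the resulting wheel (the filename varies by Python version and platform):
pip install dist/*.whl
```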
I tried the above steps with commit tensorflow/tensorflow@e24a3b5 as well, with no luck.

I'm pretty sure that tensorflow commit tensorflow/tensorflow@8cc3ffa was working on an M1, because about 4 months ago I built JAX on my old M1 Air laptop for Python 3.8, and the JAX repo at that time was using that tensorflow commit: https://github.com/shawwn/jax/blob/m1/WORKSPACE#L11

So whatever the problem is, it happened somewhere between tensorflow/tensorflow@8cc3ffa and HEAD... which is a gargantuan diff: tensorflow/tensorflow@8cc3ffa...master
I've confirmed that the older build still works for me. If anyone suffering from this problem (like @tigerneil) wants a temporary fix -- assuming you trust me -- then you can install the jaxlib wheel I built.
At that point, you should be able to run JAX.
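For example, a quick smoke test (my phrasing, not the original comment's):

```sh
python -c "import jax, jax.numpy as jnp; print(jax.__version__, jnp.zeros((2, 2)))"
```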
Note that this is actually an older jaxlib version. (I'll post about any problems I run into.) Once I get some time, I'll try doing a fresh build from head.
Happy to report that the latest jaxlib seems to build and run fine on M1! I built from jax commit d569440 and tensorflow commit tensorflow/tensorflow@071a34e. My build process was:
Patch WORKSPACE:

diff --git a/WORKSPACE b/WORKSPACE
index 70f310f9..dc6c271c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,24 +1,24 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
-# To update TensorFlow to a new revision,
-# a) update URL and strip_prefix to the new git commit hash
-# b) get the sha256 hash of the commit by running:
-# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
-# and update the sha256 with the result.
-http_archive(
- name = "org_tensorflow",
- sha256 = "7756b69b4a2a036ad7a4a8478f8bd7d69d0026d9c8c7fe8e8f1ae6205e978719",
- strip_prefix = "tensorflow-968a1751ef6ccadc30ac6bd0f0be5056ac0e9288",
- urls = [
- "https://github.com/tensorflow/tensorflow/archive/968a1751ef6ccadc30ac6bd0f0be5056ac0e9288.tar.gz",
- ],
-)
+# # To update TensorFlow to a new revision,
+# # a) update URL and strip_prefix to the new git commit hash
+# # b) get the sha256 hash of the commit by running:
+# # curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+# # and update the sha256 with the result.
+# http_archive(
+# name = "org_tensorflow",
+# sha256 = "7756b69b4a2a036ad7a4a8478f8bd7d69d0026d9c8c7fe8e8f1ae6205e978719",
+# strip_prefix = "tensorflow-968a1751ef6ccadc30ac6bd0f0be5056ac0e9288",
+# urls = [
+# "https://github.com/tensorflow/tensorflow/archive/968a1751ef6ccadc30ac6bd0f0be5056ac0e9288.tar.gz",
+# ],
+# )
# For development, one can use a local TF repository instead.
-# local_repository(
-# name = "org_tensorflow",
-# path = "tensorflow",
-# )
+local_repository(
+ name = "org_tensorflow",
+ path = "tensorflow",
+)
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()

Build jaxlib with caching:
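The command itself was not preserved; one way to get Bazel caching through jax's build script is a local disk cache, roughly as follows (the cache path is arbitrary and this is an assumption, not the exact command used):

```sh
python build/build.py --bazel_options=--disk_cache=$HOME/.cache/bazel-disk
```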
Install jaxlib:
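Presumably an install of the freshly built wheel, e.g.:

```sh
pip install dist/jaxlib-*.whl
```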
Use the local jax repo:
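Most likely an editable install of the checked-out jax repo (an assumption on my part):

```sh
pip install -e .
```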
Install scipy using the workaround here: scipy/scipy#13409
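The linked scipy issue has the authoritative steps; at the time the workaround boiled down to building scipy against Homebrew's OpenBLAS, roughly like this (a hedged sketch, not verified against that exact thread):

```sh
brew install openblas
pip install cython pybind11 pythran numpy
OPENBLAS="$(brew --prefix openblas)" pip install --no-use-pep517 scipy
```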
Test it out:

$ python3
Python 3.9.10 (main, Jan 15 2022, 11:40:36)
[Clang 13.0.0 (clang-1300.0.29.3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
~/ml/jax/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
>>> jax.__version__
'0.3.1'
>>> import jaxlib
>>> jaxlib.__version__
'0.3.1'
>>> import jax.numpy as jnp; print(jnp.zeros((1,2)))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[0. 0.]]

Thanks for fixing this!
Hi,
JAX CPU backend fails on ARM architecture (e.g. NVIDIA Jetson AGX, ARMv8.2) with the following errors:
JAX was reported to work on similar architectures (e.g. NVIDIA Jetson TX2), but I think they only used the GPU backend, which works fine.
We have identified the problem as missing AArch64 statements in XLA and LLVM build files. The following seem to fix the problem:
- Adding the missing AArch64 dependencies to cpu_compiler:deps in org_tensorflow/tensorflow/compiler/xla/service/cpu/BUILD
- Setting llvm_host_triple in org_tensorflow/third_party/llvm/llvm.autogenerated.BUILD (this is required because the default fall-back architecture is set to X86_64, which is indicated to be fixed)
- Adding a linux_aarch64 target to llvm_all_cmake_vars in org_tensorflow/third_party/llvm/llvm.bzl
With these changes the result is as follows: