
Run jaxlib in docker on Mac arm64 #13608

Closed
nikitaqwerty opened this issue Dec 11, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@nikitaqwerty

Description

I'm trying to run jaxlib inside a Docker container on a Mac with an arm64 (Apple silicon) chip.
When I try to install it through pip, I get this error:
ERROR: No matching distribution found for jaxlib>=0.3.18
I've also tried running the container from a linux/amd64 image. jaxlib installs fine through pip, but when I run my code I get the following error:
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

So what is the correct way to run jaxlib in a Docker container hosted on a Mac arm64 machine?
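To check whether the CPU that the container sees actually advertises AVX (the feature the amd64 jaxlib build needs, according to the error message above), here is a minimal sketch that parses `/proc/cpuinfo` inside a Linux container. It assumes the usual kernel format, where x86 lists features on a `flags` line and arm64 on a `Features` line; `cpu_flags` and `has_avx` are hypothetical helper names, not part of jax.

```python
from __future__ import annotations

import platform
from pathlib import Path


def cpu_flags(cpuinfo_text: str) -> set[str]:
    """Collect the CPU feature flags listed in /proc/cpuinfo-style text."""
    flags: set[str] = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        # x86 kernels list features on "flags" lines; arm64 kernels use "Features".
        if sep and key.strip() in ("flags", "Features"):
            flags.update(value.split())
    return flags


def has_avx(cpuinfo_text: str) -> bool:
    """True if the advertised feature flags include plain AVX."""
    return "avx" in cpu_flags(cpuinfo_text)


if __name__ == "__main__":
    cpuinfo = Path("/proc/cpuinfo")
    print("machine:", platform.machine())  # e.g. "aarch64" or "x86_64"
    if cpuinfo.exists():
        print("AVX advertised:", has_avx(cpuinfo.read_text()))
    else:
        print("/proc/cpuinfo not found (not running on Linux?)")
```

Run inside a linux/amd64 container on an M1, a result of `AVX advertised: False` would be consistent with the import-time error described above.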

What jax/jaxlib version are you using?

jaxlib 0.3.25

Which accelerator(s) are you using?

CPU

Additional system info

Mac Apple M1 Pro

NVIDIA GPU info

No response

nikitaqwerty added the bug label on Dec 11, 2022
@zhangqiaorjc
Collaborator

I don't run a hosted Docker container, but on my M1 the following version upgrade works fine:

Python 3.10.1 (main, Mar  7 2022, 13:38:25) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.3.20'
>>> import jaxlib
>>> jaxlib.__version__
'0.3.20'
>>>
zhangqiaorjc-macbookpro2:~ zhangqiaorjc$ pip install --upgrade pip
Requirement already satisfied: pip in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (22.3)
Collecting pip
  Downloading pip-22.3.1-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 10.9 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 22.3
    Uninstalling pip-22.3:
      Successfully uninstalled pip-22.3
Successfully installed pip-22.3.1
zhangqiaorjc-macbookpro2:~ zhangqiaorjc$
zhangqiaorjc-macbookpro2:~ zhangqiaorjc$ pip install --upgrade "jax[cpu]"
Requirement already satisfied: jax[cpu] in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (0.3.20)
Collecting jax[cpu]
  Downloading jax-0.3.25.tar.gz (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 9.1 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: numpy>=1.20 in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (from jax[cpu]) (1.22.2)
Requirement already satisfied: opt_einsum in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (from jax[cpu]) (3.3.0)
Requirement already satisfied: scipy>=1.5 in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (from jax[cpu]) (1.8.0)
Requirement already satisfied: typing_extensions in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (from jax[cpu]) (4.1.1)
Collecting jaxlib==0.3.25
  Downloading jaxlib-0.3.25-cp310-cp310-macosx_11_0_arm64.whl (51.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.6/51.6 MB 33.2 MB/s eta 0:00:00
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.3.25-py3-none-any.whl size=1308509 sha256=94ad7e9a1e0cb168b6831b9540b7f1e07020a663e6b3044b51651cdd37550013
  Stored in directory: /Users/zhangqiaorjc/Library/Caches/pip/wheels/45/65/e5/9d8dd300a2181533ee4d74780d647245740bbff982b700903a
Successfully built jax
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.20
    Uninstalling jaxlib-0.3.20:
      Successfully uninstalled jaxlib-0.3.20
  Attempting uninstall: jax
    Found existing installation: jax 0.3.20
    Uninstalling jax-0.3.20:
      Successfully uninstalled jax-0.3.20
Successfully installed jax-0.3.25 jaxlib-0.3.25
zhangqiaorjc-macbookpro2:~ zhangqiaorjc$ python
Python 3.10.1 (main, Mar  7 2022, 13:38:25) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.3.25'
>>> import jaxlib
>>> jaxlib.__version__
'0.3.25'
>>> jax.device_put(1)+1
DeviceArray(2, dtype=int32, weak_type=True)

Can you try upgrading pip and then following these instructions?
https://github.com/google/jax#pip-installation-cpu

If it still doesn't work, could you post the commands you used, along with your Python version and any other relevant details about your environment?
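For gathering the environment details requested above, a small standard-library-only sketch that reports the Python version, OS, CPU architecture, and the installed jax, jaxlib, and pip versions. `installed_version` and `environment_report` are hypothetical helpers, not jax APIs; a missing package is reported as `None` rather than raising.

```python
from __future__ import annotations

import platform
import sys
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(packages=("jax", "jaxlib", "pip")) -> dict[str, Optional[str]]:
    """Gather the details worth pasting into a bug report."""
    report: dict[str, Optional[str]] = {
        "python": sys.version.split()[0],
        "system": platform.system(),    # e.g. "Linux" inside a container, "Darwin" on macOS
        "machine": platform.machine(),  # e.g. "aarch64", "arm64", or "x86_64"
    }
    for name in packages:
        report[name] = installed_version(name)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```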

@hawkinsp
Collaborator

Are you using a Linux ARM64 image in Docker? If so, the issue is simply that we haven't released Linux ARM wheels yet (only Mac ARM wheels); see #7097.

You can fix this by building jaxlib from source.

Eventually we are likely to release Linux aarch64 wheels, but we haven't done so at this time.
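The options discussed in this thread can be summarized as a small decision helper. This is a sketch that encodes only what the comments here state about wheel availability as of December 2022; `jaxlib_install_advice` is a hypothetical name, and the returned strings are advice, not commands that jax itself provides.

```python
def jaxlib_install_advice(system: str, machine: str, cpu_has_avx: bool) -> str:
    """Map a platform to the option described in this thread (status: Dec 2022)."""
    system = system.lower()
    machine = machine.lower()
    if system == "darwin" and machine == "arm64":
        # Native macOS on Apple silicon: prebuilt arm64 wheels are published.
        return "pip install --upgrade pip && pip install --upgrade 'jax[cpu]'"
    if system == "linux" and machine in ("aarch64", "arm64"):
        # linux/arm64 containers: no Linux ARM wheels released yet (see #7097).
        return "build jaxlib from source for aarch64"
    if system == "linux" and machine in ("x86_64", "amd64"):
        if cpu_has_avx:
            return "pip install --upgrade 'jax[cpu]'"
        # linux/amd64 under emulation: the wheel installs but needs AVX at import.
        return "build an x86-64 jaxlib without AVX from source"
    return "not covered in this thread"
```

Inside a container, the inputs could come from `platform.system()`, `platform.machine()`, and an AVX check on `/proc/cpuinfo`.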

I hope that helps!

@ridwan-salau

The answer here solves the problem easily for me and many others (as seen in the number of upvotes).

@dexianta

> The answer here solves the problem easily for me and many others (as seen in the number of upvotes)

If you use a linux/amd64 image on an M1, you can indeed install jaxlib, but you won't be able to import JAX, because it throws errors about AVX instructions.

@hawkinsp
Collaborator

@dexianta Yes, that's as expected. The x86-64 (amd64) build of JAX uses AVX instructions, which Rosetta doesn't support. You either need to build an x86-64 jaxlib without AVX support, or you need to use an aarch64 build of jaxlib. Both require that you build jaxlib from source at the moment.
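Both routes go through jaxlib's source build. Here is a minimal sketch that assembles the build invocation from the root of a jax source checkout. It assumes the `build/build.py` script and its `--target_cpu_features` option (where `default` opts out of extra architectural features such as AVX); treat that flag as an assumption and confirm it with `python build/build.py --help` for your checkout. `jaxlib_build_command` and `run_build` are hypothetical helpers.

```python
from __future__ import annotations

import glob
import shlex
import subprocess
import sys


def jaxlib_build_command(disable_avx: bool) -> list[str]:
    """Assemble a build/build.py invocation (run from a jax source checkout).

    ASSUMPTION: build/build.py accepts --target_cpu_features, and 'default'
    avoids opting in to extra CPU features such as AVX. Verify with --help.
    """
    cmd = [sys.executable, "build/build.py"]
    if disable_avx:
        cmd += ["--target_cpu_features", "default"]
    return cmd


def run_build(disable_avx: bool, dry_run: bool = True) -> list[str]:
    """Print (and optionally run) the build, then list jaxlib wheels in dist/."""
    cmd = jaxlib_build_command(disable_avx)
    print("$", shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)
    # Install the resulting wheel afterwards, e.g.: pip install dist/jaxlib-*.whl
    return sorted(glob.glob("dist/jaxlib-*.whl"))
```

In a linux/arm64 container, `disable_avx=False` is enough, since AVX is an x86 feature; in an emulated linux/amd64 container, `run_build(disable_avx=True, dry_run=False)` would target the no-AVX x86-64 build described above.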


5 participants