Jax doesn't see my GPU, even though Pytorch does #5231

Closed
asemic-horizon opened this issue Dec 20, 2020 · 19 comments · Fixed by #5394

@asemic-horizon

Jax sounds like an impressive project, thanks for working on it.

That said: on Ubuntu 18.04, this happens

➜  python
Python 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import torch, jax; print(torch.cuda.is_available()); print(jax.devices())
True
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

I first tried to pip install jax and got various errors; the error messages said this was common with old versions of pip that didn't support newer kinds of wheels, and directed me to upgrade pip, which I did (from 9 to 20). Now Jax seems to be installed (at least various numpy-compatible functions work), but not to the point where it sees my GPU, a laptop "Geforce" card by Nvidia.

I'm not sure what systems diagnostics I can bring to help. This is the name of my card:

✗  lspci | grep NVIDIA
01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1)

As best as I understand it, these are drivers:

✗  lsmod | grep nvidia
nvidia_uvm            970752  0
nvidia_drm             53248  3
nvidia_modeset       1212416  2 nvidia_drm
nvidia              27643904  103 nvidia_uvm,nvidia_modeset
drm_kms_helper        184320  2 nvidia_drm,i915
drm                   491520  8 drm_kms_helper,nvidia_drm,i915

Thanks for reading this anyway.

@8bitmp3
Contributor

8bitmp3 commented Dec 20, 2020

Hello 👋 Just to confirm - did you follow the Linux-specific installation instructions from the README? Also, have you tried installing JAX in a separate virtual environment that excludes PyTorch? 🤷‍♀️

https://github.com/google/jax#installation

On Linux, it is often necessary to first update pip to a version that supports manylinux2010 wheels.

If you want to install JAX with both CPU and GPU support, using existing CUDA and CUDNN7 installations on your machine (for example, preinstalled on your cloud VM), you can run

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use, with cuda110 for CUDA 11.0, cuda102 for CUDA 10.2, and cuda101 for CUDA 10.1. You can find your CUDA version with the command:

nvcc --version

Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system, you can either create a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

Or set the following environment variable before importing JAX:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
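If you would rather set that variable from inside a Python script than in the shell, something like the following might work. It is only a sketch: the helper name `set_cuda_data_dir` is made up for illustration, `/path/to/cuda` is still a placeholder, and it assumes no flag value contains spaces.

```python
import os

CUDA_DIR_FLAG = "--xla_gpu_cuda_data_dir"


def set_cuda_data_dir(cuda_dir, env=None):
    """Add the CUDA data dir flag to XLA_FLAGS, replacing any earlier value.

    `set_cuda_data_dir` is a hypothetical helper, not part of JAX.
    """
    env = os.environ if env is None else env
    # Keep unrelated flags, drop a previous setting of this particular flag.
    kept = [
        flag
        for flag in env.get("XLA_FLAGS", "").split()
        if not flag.startswith(CUDA_DIR_FLAG + "=")
    ]
    env["XLA_FLAGS"] = " ".join(kept + [f"{CUDA_DIR_FLAG}={cuda_dir}"])
    return env["XLA_FLAGS"]


# This has to run before the first `import jax` in the process, e.g.:
# set_cuda_data_dir("/path/to/cuda")
# import jax
```

The order matters because the variable needs to be set before JAX is imported, as noted above.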

@GJBoth
Contributor

GJBoth commented Dec 21, 2020

I had the same issue, but managed to solve it. It seems pytorch bundles its own cuda, so that's why you don't have to install it separately but it sees your gpu and nvidia-smi works. Installing cuda for your GPU following these instructions solved the issue for me: https://developer.nvidia.com/cuda-downloads

@8bitmp3
Contributor

8bitmp3 commented Dec 21, 2020

@GJBoth That's awesome. Can you confirm the following (and please correct me if I'm wrong):

  • Having JAX in a separate env should technically help identify if JAX can detect your CUDA (and not PyTorch's bundled one)
  • The current instructions assume that you've taken care of your CUDA installation (see extract below) but maybe it would help to nudge the users to go to https://developer.nvidia.com/cuda-downloads and install CUDA, if they haven't already.

"... using existing CUDA and CUDNN7 installations on your machine (for example, preinstalled on your cloud VM)..."
...
"Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system..."

WDYT? @asemic-horizon @GJBoth

@GJBoth
Contributor

GJBoth commented Dec 21, 2020

I have no experience installing cuda in a specific env. It seems that the symbolic link wouldn't work, see this thread:
https://discuss.pytorch.org/t/where-is-cudatoolkit-path-when-installed-via-conda/47791/5

I think many new jax users will come from pytorch, so adding a nudge a la 'If you're coming from pytorch, make sure to install cuda separately, if you haven't yet' would help. Two more observations: I could run nvidia-smi, but not nvcc, so this might be a nice check to see if you have pytorch's bundled cuda or a systemwide one. Furthermore, jupyter notebooks tend to die silently with these issues, so running things as a script gives you much more info.
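A rough sketch of that nvidia-smi vs. nvcc check as a script. It only looks for the two executables on PATH, so a toolkit installed somewhere that isn't on PATH would be reported as missing; the function name and messages are just illustrative.

```python
import shutil


def cuda_tool_report(which=shutil.which):
    """Report which of nvidia-smi / nvcc can be found on PATH.

    `which` is injectable so the logic can be exercised without a GPU.
    """
    has_smi = which("nvidia-smi") is not None
    has_nvcc = which("nvcc") is not None
    if has_smi and has_nvcc:
        return "driver and CUDA toolkit both found on PATH"
    if has_smi:
        return "driver found, but nvcc is missing: no system-wide CUDA toolkit on PATH"
    if has_nvcc:
        return "nvcc found, but nvidia-smi is missing: check the NVIDIA driver"
    return "neither nvidia-smi nor nvcc found on PATH"


if __name__ == "__main__":
    print(cuda_tool_report())
```

The second case (driver found, nvcc missing) is the situation I was in.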

@myjr52

myjr52 commented Mar 21, 2021

I had the same problem with jax not recognizing GPU. I did the following two steps:

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and made a link as follows:

ln -s /usr/lib/nvidia-cuda-toolkit /usr/local/cuda-10.1

After this, jax still didn't recognize the GPU. Then I did the following steps, hinted at by the warning message in jax about the GPU:

cd /usr/lib/nvidia-cuda-toolkit
mkdir nvvm
cd nvvm
sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

You would need to use "sudo" for the above steps. After these, jax recognises my GPU.
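To check whether the layout is in place before or after making the links, a small script along these lines could help. It is a sketch that assumes, based on the steps above, that the bitcode files are looked up under `<cuda_dir>/nvvm/libdevice` and are named like `libdevice.10.bc`; the function name is made up.

```python
from pathlib import Path


def find_libdevice(cuda_dir):
    """List libdevice bitcode files under <cuda_dir>/nvvm/libdevice.

    Returns an empty list if the directory (or a symlink to it) is missing.
    """
    libdevice_dir = Path(cuda_dir) / "nvvm" / "libdevice"
    if not libdevice_dir.is_dir():  # is_dir() follows symlinks
        return []
    return sorted(p.name for p in libdevice_dir.glob("libdevice*.bc"))


if __name__ == "__main__":
    # Same directory as in the steps above; adjust for your system.
    print(find_libdevice("/usr/lib/nvidia-cuda-toolkit"))
```

An empty list would suggest the nvvm/libdevice link is still missing or points somewhere without .bc files.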

@morawi

morawi commented Apr 16, 2021

> sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice

I am remotely connected to a slurm cluster and do not have sudo rights. In fact, I do not even have permission to make a symbolic link.
Plus, my environment has no GPU, the GPU is assigned via the sbatch job file using the command "#SBATCH --gres=gpu:1" .
This is way too complicated. Yet, PyTorch seems to work perfectly well.

@yongjiezhu

> > sudo ln -s /usr/lib/nvidia-cuda-toolkit/libdevice libdevice
>
> I am remotely connected to a slurm cluster and do not have sudo rights. In fact, I do not even have permission to make a symbolic link.
> Plus, my environment has no GPU, the GPU is assigned via the sbatch job file using the command "#SBATCH --gres=gpu:1" .
> This is way too complicated. Yet, PyTorch seems to work perfectly well.

Hi, I have the same problem when remotely connected to a slurm cluster. Do you solve this issue? How to solve it?

@morawi

morawi commented Apr 27, 2021

> Do you solve this issue?

Unfortunately not. I don't have sudo control over the cluster and this makes it hard. The best way for JAX is that they ship a cuda bundle with the installation, similar to PyTorch.

@hawkinsp
Collaborator

hawkinsp commented Apr 28, 2021

FYI #6581 bundles libdevice.10.bc with jaxlib wheels, which hopefully will help avoid this particular problem. If you are feeling motivated, you could try patching in that PR and building a jaxlib from source to see if it fixes your problems.

@josephrocca
Contributor

josephrocca commented Sep 3, 2021

Similar to myjr52, I was able to solve this simply by replacing this:

pip install --upgrade jax jaxlib

with this (you'll have to change cuda111 based on your output of nvcc --version - mine is 11.1):

pip install --upgrade jax jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I didn't need to do any of the additional steps mentioned by myjr52.
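If it helps, here is a small sketch that derives the suffix from the nvcc output instead of reading it off by eye. It assumes the `cuda<major><minor>` naming used by the tags mentioned in this thread (cuda101, cuda102, cuda110, cuda111); other releases may be named differently, so check the releases page before relying on the result. The function names are made up.

```python
import re
import subprocess


def cuda_suffix(nvcc_output):
    """Turn `nvcc --version` output into a tag such as 'cuda111'.

    Assumes the cuda<major><minor> scheme; not every release follows it.
    """
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    major, minor = match.groups()
    return f"cuda{major}{minor}"


def detect_cuda_suffix():
    """Run nvcc and parse its output; raises if nvcc is not on PATH."""
    result = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=True
    )
    return cuda_suffix(result.stdout)
```

With CUDA 11.1, `cuda_suffix` gives `cuda111`, which is the tag used in the command above.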

@Oaklight

Guess it's a bit late for this, but I got mine fixed by specifying the exact whl link found at https://storage.googleapis.com/jax-releases/jax_releases.html. I just needed CUDA 11.0.

The one I used was:

pip uninstall jax jaxlib -y
pip install https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.71+cuda110-cp38-none-manylinux2010_x86_64.whl
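When picking a wheel by hand like this, the Python tag in the file name (cp38 here) has to match the interpreter you install into. A minimal sketch of that check follows. It assumes the usual name-version-python-abi-platform wheel naming, and both function names are made up.

```python
import sys


def wheel_python_tag(wheel_name):
    """Return the Python tag (e.g. 'cp38') from a wheel file name."""
    stem = wheel_name[: -len(".whl")] if wheel_name.endswith(".whl") else wheel_name
    # name-version[-build]-python-abi-platform: the Python tag is third from the end.
    return stem.split("-")[-3]


def matches_running_python(wheel_name, version_info=sys.version_info):
    """True if the wheel's CPython tag matches the given interpreter version."""
    expected = f"cp{version_info[0]}{version_info[1]}"
    return wheel_python_tag(wheel_name) == expected


if __name__ == "__main__":
    name = "jaxlib-0.1.71+cuda110-cp38-none-manylinux2010_x86_64.whl"
    print(wheel_python_tag(name), matches_running_python(name))
```

If the second value is False, pick the wheel built for your Python version from the same page.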

@jithendaraa

> > Do you solve this issue?
>
> Unfortunately not. I don't have sudo control over the cluster and this makes it hard. The best way for JAX is that they ship a cuda bundle with the installation, similar to PyTorch.

@morawi Curious to know if this is solved for you yet since I'm going through the same thing with JAX on a slurm cluster

@morawi

morawi commented Nov 13, 2021

> > > Do you solve this issue?
> >
> > Unfortunately not. I don't have sudo control over the cluster and this makes it hard. The best way for JAX is that they ship a cuda bundle with the installation, similar to PyTorch.
>
> @morawi Curious to know if this is solved for you yet since I'm going through the same thing with JAX on a slurm cluster

I just stopped using it.

@ritog

ritog commented Dec 26, 2021

I have solved this problem very easily just following this issue, some googling and stumbling upon two SO questions, and the readme of this project.

I had nvidia drivers installed on my laptop through the Pop OS store, and I installed nvidia-cuda-toolkit through apt, and then installed PyTorch (earlier).

My cuda version is 11.2.

I did not have to do any other installation.

  1. I upgraded pip.
  2. I installed jax[cuda11] instead of just jax.
  3. Followed other generic instructions from the ReadMe.
  4. Created two symlinks: one for nvcc and another for cuda.

Then it started working great.

@juanfolco

If someone else stumbles into this, the CUDA wheel releases are now stored on https://storage.googleapis.com/jax-releases/jax_cuda_releases.html for some reason.

@FrancescoSaverioZuppichini

FrancescoSaverioZuppichini commented Oct 10, 2022

Same issue, using

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
...
RUN python3 -m pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
... 

Jax doesn't see the gpu

import jax
print(jax.devices())

only cpu
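One thing that may be worth checking in a case like this is whether the jaxlib that ended up installed is a CUDA build at all. In this thread the CUDA wheels carry a local version tag such as +cuda110 or +cuda11.cudnn82, so a sketch like the following can tell the two apart. It assumes that tagging scheme, which may not hold for other jaxlib releases, and the function names are made up.

```python
from importlib import metadata


def cuda_tag(version):
    """Return the local version tag if it mentions cuda, else None."""
    _, _, local = version.partition("+")
    return local if "cuda" in local else None


def installed_jaxlib_cuda_tag():
    """Look up the installed jaxlib version; None if absent or CPU-only."""
    try:
        return cuda_tag(metadata.version("jaxlib"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_jaxlib_cuda_tag())
```

A result of None would suggest that pip resolved to a CPU-only jaxlib, or that jaxlib is not installed in this environment.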

@aquaresima

Hello,

same issue here! Torch can find my GPU, JAX does not!

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

pacman -Q cudnn
cudnn 8.5.0.96-1

I followed the installation guide here:

When I run
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
It first asks me for a storage.googleapis.com user (WTF ??),
then when I just press Enter it returns an access error!

Collecting jaxlib==0.3.22+cuda11.cudnn82
User for storage.googleapis.com:   WARNING: 401 Error, Credentials not correct for https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl
  ERROR: HTTP error 401 while getting https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
ERROR: Could not install requirement jaxlib==0.3.22+cuda11.cudnn82 from https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from jax[cuda]) because of HTTP error 401 Client Error: Unauthorized for url: https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl for URL https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22%2Bcuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl (from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
FAIL: 1


If I run

pip install --upgrade "jax[cuda]" or pip install "jax[cuda11_cudnn82]"

Then it walks back through older versions until it reaches the first one that does not have the cuda extra! (WTF^2)

....
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.25.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.22-py3-none-any.whl
WARNING: jax 0.2.22 does not provide the extra 'cuda'
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.3.22
    Uninstalling jax-0.3.22:
      Successfully uninstalled jax-0.3.22
Successfully installed jax-0.2.22

The pip install --upgrade "jax[all]" runs just fine: Successfully installed jax-0.3.22
but GPU access is not available (see topmost).

Thanks for the help

@aquaresima

Update:

If I download my appropriate version

[cuda11/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl](https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl)

and run:

pip install --upgrade "jax[cuda11_cudnn82]" -f ~/Downloads/jax

Here we go again:

Looking in links: /home/cocconat/Downloads/jax
Requirement already satisfied: jax[cuda11_cudnn82] in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (0.3.22)
Requirement already satisfied: absl-py in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.2.0)
Requirement already satisfied: opt-einsum in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (3.3.0)
Requirement already satisfied: numpy>=1.20 in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.23.3)
Requirement already satisfied: etils[epath] in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (0.8.0)
Requirement already satisfied: typing-extensions in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (4.3.0)
Requirement already satisfied: scipy>=1.5 in /home/cocconat/.virtualenvs/learning/lib/python3.10/site-packages (from jax[cuda11_cudnn82]) (1.9.1)
Collecting jax[cuda11_cudnn82]
  Using cached jax-0.3.22-py3-none-any.whl
  Using cached jax-0.3.21.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.20.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.19.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.17-py3-none-any.whl
  Using cached jax-0.3.16.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.13.tar.gz (951 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.12.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.11.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.10.tar.gz (939 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.9.tar.gz (937 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.8.tar.gz (935 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.7.tar.gz (944 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.6.tar.gz (936 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.4.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.3.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.2.tar.gz (926 kB)
^C  Preparing metadata (setup.py) ... canceled
ERROR: Operation cancelled by user
FAIL: 1


@hawkinsp
Collaborator

@aquaresima Please open a new issue, please don't ping long-closed issues.
