Jax doesn't see my GPU, even though Pytorch does #5231
Comments
Hello 👋 Just to confirm: did you follow the Linux-specific installation instructions from the README? Also, have you tried installing JAX in a separate virtual environment that excludes PyTorch? 🤷‍♀️ https://github.com/google/jax#installation
|
I had the same issue, but managed to solve it. PyTorch seems to bundle its own CUDA, which is why you don't have to install CUDA separately for PyTorch to see your GPU (and nvidia-smi works too). Installing CUDA for your GPU by following these instructions solved the issue for me: https://developer.nvidia.com/cuda-downloads |
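To check what CUDA a PyTorch install brings with it, a minimal sketch can ask PyTorch directly. It uses `torch.cuda.is_available()` and `torch.version.cuda` (the CUDA version the installed PyTorch build was compiled against, `None` for CPU-only builds); the helper name `torch_cuda_info` is hypothetical, and it returns `None` when PyTorch isn't installed.

```python
def torch_cuda_info():
    """Summarize PyTorch's view of CUDA, or return None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None
    return {
        # True if PyTorch can reach a GPU through the NVIDIA driver.
        "cuda_available": torch.cuda.is_available(),
        # CUDA version this PyTorch build was compiled against (None for CPU builds).
        "built_with_cuda": torch.version.cuda,
    }


if __name__ == "__main__":
    print(torch_cuda_info())
```

If `cuda_available` is True while no system-wide CUDA toolkit is installed, that matches the situation described here: PyTorch is using its own CUDA libraries, which JAX did not ship at the time of this thread.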
@GJBoth That's awesome. Can you confirm the following (and please correct me if I'm wrong):
WDYT? @asemic-horizon @GJBoth |
I have no experience installing CUDA in a specific env. It seems that the symbolic link wouldn't work; see this thread: I think many new JAX users will come from PyTorch, so it may be worth adding a nudge à la 'If you're coming from PyTorch, make sure to install CUDA separately, if you haven't yet.' Two more observations: I could run nvidia-smi but not nvcc, so this might be a nice check to see whether you have PyTorch's CUDA or a system-wide install. Furthermore, Jupyter notebooks tend to die silently with these issues, so running things as a script gives you much more info. |
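The nvidia-smi vs. nvcc check can be scripted. nvidia-smi is installed with the NVIDIA driver, while nvcc comes with the CUDA toolkit, so finding the first but not the second hints that only a driver (plus whatever PyTorch bundles) is present. A minimal sketch, assuming the tools would be on `PATH` if installed; `cuda_tool_report` is a hypothetical helper name.

```python
import shutil


def cuda_tool_report():
    """Report whether nvidia-smi and nvcc can be found on PATH.

    nvidia-smi ships with the NVIDIA driver; nvcc ships with the CUDA toolkit.
    """
    return {tool: shutil.which(tool) is not None for tool in ("nvidia-smi", "nvcc")}


if __name__ == "__main__":
    report = cuda_tool_report()
    for tool, found in report.items():
        print(f"{tool}: {'found' if found else 'not found'}")
    if report["nvidia-smi"] and not report["nvcc"]:
        print("Driver present, but no CUDA toolkit on PATH.")
```

A toolkit can also be installed somewhere that isn't on `PATH` (for example a `bin` directory that was never added), so "not found" is a hint rather than proof.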
I had the same problem with JAX not recognizing my GPU. I did the following two steps:
pip install --upgrade pip
ln -s /usr/lib/nvidia-cuda-toolkit /usr/local/cuda-10.1
After this, JAX still didn't recognize the GPU. Then I followed the steps hinted at by JAX's warning message about the GPU:
cd /usr/lib/nvidia-cuda-toolkit
You need to use "sudo" for the above steps. After these, JAX recognises my GPU. |
I am remotely connected to a slurm cluster and do not have sudo rights. In fact, I do not even have permission to make a symbolic link. |
Hi, I have the same problem while connected remotely to a Slurm cluster. Did you solve this issue? If so, how? |
Unfortunately not. I don't have sudo access on the cluster, and that makes it hard. The best option would be for JAX to ship a CUDA bundle with the installation, as PyTorch does. |
FYI #6581 bundles |
Similar to myjr52, I was able to solve this simply by replacing this:
with this (you'll have the change
I didn't need to do any of the additional steps mentioned by myjr52. |
I guess it's a bit late for this, but I fixed mine by specifying the exact wheel link found at https://storage.googleapis.com/jax-releases/jax_releases.html. I just needed CUDA 11.0. The one I used was:
|
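At the time of this thread, the CUDA-enabled jaxlib wheels from that releases page carried a PEP 440 local version label naming the CUDA version (for example `+cuda110` for CUDA 11.0), while a plain `pip install jaxlib` gave a CPU-only build. Assuming that naming scheme, a small sketch can tell which build is installed; `local_version_label` and `installed_jaxlib_label` are hypothetical helper names.

```python
from importlib import metadata


def local_version_label(version):
    """Return the PEP 440 local label after '+', or '' if there is none."""
    _, sep, label = version.partition("+")
    return label if sep else ""


def installed_jaxlib_label():
    """Local version label of the installed jaxlib, or None if not installed."""
    try:
        return local_version_label(metadata.version("jaxlib"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    label = installed_jaxlib_label()
    if label is None:
        print("jaxlib is not installed")
    elif label.startswith("cuda"):
        print(f"CUDA build of jaxlib: {label}")
    else:
        print("No CUDA label on jaxlib (CPU-only wheel under the old naming scheme)")
```

Newer JAX releases may package CUDA support differently, so an empty label there does not by itself mean a CPU-only install.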
@morawi Curious to know whether this is solved for you yet, since I'm going through the same thing with JAX on a Slurm cluster. |
I just stopped using it. |
I solved this problem quite easily just by following this issue, some googling (which turned up two Stack Overflow questions), and the README of this project. I had NVIDIA drivers installed on my laptop through the Pop!_OS store, and I installed My CUDA version is 11.2. I did not have to do any other installation.
Then it started working great. |
If someone else stumbles onto this: the CUDA wheel releases are now stored at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html for some reason. |
Same issue, using
JAX doesn't see the GPU:
import jax
print(jax.devices())
This lists only the CPU. |
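The `jax.devices()` check can be extended to also report JAX's default backend, and run as a script rather than in a notebook so any warnings are visible. A minimal sketch using `jax.default_backend()` and `jax.devices()`; the helper name `jax_device_summary` is hypothetical, and it returns `None` if JAX isn't installed.

```python
def jax_device_summary():
    """Return JAX's default backend and its devices, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return {
        # "cpu" here means JAX fell back to the CPU backend.
        "default_backend": jax.default_backend(),
        # String forms of the devices on the default backend.
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_device_summary())
```

If `default_backend` comes back as `"cpu"` on a machine with an NVIDIA GPU, JAX is reproducing the problem described in this issue.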
Hello, same issue here! Torch can find my GPU, JAX does not!
I followed the installation guide here: When I run
If I run
Then it goes back through the versions until it reaches the very first one that doesn't have CUDA! (WTF^2)
The Thanks for the help |
Update: If I download my appropriate version
and run:
Here we go again:
|
@aquaresima Please open a new issue; please don't ping long-closed issues. |
JAX sounds like an impressive project; thanks for working on it.
That said: on Ubuntu 18.04, this happens
I first tried to pip install jax and got various errors; the error messages said this was common with old versions of pip that don't support newer wheel formats, and told me to upgrade pip, which I did (from 9 to 20). Now JAX seems to be installed (at least various NumPy-compatible functions work), but it doesn't appear to see my GPU, a laptop NVIDIA GeForce card.
I'm not sure what system diagnostics I can provide to help. This is the name of my card:
As best as I understand it, these are drivers:
Thanks for reading this anyway.
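For reports like this one, useful diagnostics include the OS, the Python version, and the installed `pip`, `jax`, and `jaxlib` versions. A minimal standard-library sketch; `environment_report` is a hypothetical helper name, not part of JAX.

```python
import platform
from importlib import metadata


def environment_report(packages=("pip", "jax", "jaxlib")):
    """Collect OS, Python, and package versions for a bug report."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {
        "os": platform.platform(),
        "python": platform.python_version(),
        "packages": versions,
    }


if __name__ == "__main__":
    report = environment_report()
    print(f"OS: {report['os']}")
    print(f"Python: {report['python']}")
    for name, version in report["packages"].items():
        print(f"{name}: {version or 'not installed'}")
```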