Significant difference in empirical NTK for batched and non-batched versions #122
Comments
Thanks for reaching out! Unfortunately, I wasn't able to reproduce the issue you're having. Here's a colab notebook (note: I had to make the model slightly smaller to fit in public Colab GPU memory). All of the package versions seem to match your current environment. However, I notice your GPU drivers seem a little out-of-date (CUDA 10.2 instead of 11.2); is it possible that is the issue?
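A minimal sketch of how one might compare the two environments being discussed (this snippet is not from the thread itself; `jax.__version__`, `jaxlib.version` and `jax.devices` are standard JAX APIs):

```python
# Print the library versions and visible devices; the CUDA driver version
# itself is easiest to read off `nvidia-smi`.
import jax
from jaxlib import version as jaxlib_version

print("jax:", jax.__version__)
print("jaxlib:", jaxlib_version.__version__)
print(jax.devices())  # should list GPU devices if the CUDA backend is active
```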
Btw, also consider trying …
Thanks for the quick answers! Curiously, it seems that using … On the other hand, I also cannot reproduce this strange behaviour on the Google Colab. I will try to update my CUDA version and get back to you.
After updating to CUDA-11.4, I can confirm that the issue indeed only happens on the old CUDA version. With this new version, both … In fact, it seems that support for CUDA-10.2 will be dropped in the next release of JAX. Even if this happens soon, I would still recommend directly specifying CUDA-11.x as a dependency of … In any case, thank you very much for all your help!
Thanks a lot for figuring this out! Just pushed a release (https://github.com/google/neural-tangents/releases/tag/v0.3.9) bumping our minimum JAX version up to 0.2.25, which itself should only work with CUDA-11 and higher, so hopefully this should be fixed! Please feel free to re-open if the issue remains.
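After upgrading, a one-line sanity check (a sketch; it assumes `neural_tangents` exposes `__version__`, as recent releases do):

```python
import neural_tangents as nt

print(nt.__version__)  # expect >= 0.3.9 for the JAX >= 0.2.25 / CUDA-11 pin
```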
Original issue:

After updating my environment to work with more recent versions of JAX and Flax, I have noticed that the empirical NTK Gram matrices computed using `nt.batch` applied to `nt.empirical_kernel_fn` are significantly different depending on the batch size. The code to reproduce this error is:
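(The original reproduction script was lost when this page was archived; below is a minimal sketch of the comparison being described. The model, data shapes, and batch size are invented, and an `nt.stax` network stands in for the Flax model mentioned above to keep the example self-contained.)

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# A stand-in model; the original issue used a Flax network.
init_fn, apply_fn, _ = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1)
)

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (64, 32))  # (num_samples, input_dim), shapes invented
_, params = init_fn(key, data.shape)

# Empirical NTK Gram matrix, computed in one shot and in batches.
kernel_fn = nt.empirical_kernel_fn(apply_fn)
batched_kernel_fn = nt.batch(kernel_fn, batch_size=16)

ntk_full = kernel_fn(data, None, 'ntk', params)
ntk_batched = batched_kernel_fn(data, None, 'ntk', params)

# Average element-wise discrepancy between the two Gram matrices.
print(jnp.mean(jnp.abs(ntk_full - ntk_batched)))
```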
Surprisingly, if I run this with my old environment I get an average error on the order of `1e-8`, while with the new environment the error is on the order of `1e-1`. Also, this error remains exactly the same as long as `batch_size < data.shape[0]`.

My old environment consisted of: …

and my new environment has: …