
Significant difference in empirical NTK for batched and non-batched versions #122

Closed
gortizji opened this issue Sep 2, 2021 · 5 comments
Labels
bug Something isn't working

Comments

@gortizji

gortizji commented Sep 2, 2021

After updating my environment to work with a more recent version of JAX and FLAX, I have noticed that empirical the NTK Gram matrices computed using nt.batch applied to nt.empirical_kernel_fn are significantly different depending on the batch size.

After updating my environment to work with more recent versions of JAX and Flax, I have noticed that the empirical NTK Gram matrices computed by applying nt.batch to nt.empirical_kernel_fn differ significantly depending on the batch size.

The code to reproduce this error is:

import jax.numpy as jnp
import flax.linen as nn
import functools
import jax
import neural_tangents as nt


class LeNet(nn.Module):
    kernel_size = (5, 5)
    strides = (2, 2)
    window_shape = (2, 2)
    num_classes = 1
    features = (6, 16, 120, 84, 1)
    pooling = True
    padding = "SAME"

    @nn.compact
    def __call__(self, x):
        conv = functools.partial(nn.Conv, padding=self.padding)
        x = conv(features=self.features[0], kernel_size=tuple(self.kernel_size))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))

        x = conv(features=self.features[1], kernel_size=tuple(self.kernel_size))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))

        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.features[2])(x)
        x = nn.relu(x)
        x = nn.Dense(self.features[3])(x)
        x = nn.relu(x)

        x = nn.Dense(self.num_classes)(x)
        return x

model_key, data_key = jax.random.split(jax.random.PRNGKey(42))
data = jax.random.normal(data_key, [500, 32, 32, 3])
model = LeNet()
init_params = model.init(model_key, jnp.zeros([1, 32, 32, 3]))

# Compute NTK Gram matrix using the fully parallel version
kernel_full_fn = nt.batch(
    nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
    batch_size=500,
    device_count=-1,
    store_on_device=False,
)
K_full = kernel_full_fn(data, None, "ntk", init_params)

# Compute NTK Gram matrix using minibatches
kernel_batch_fn = nt.batch(
    nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
    batch_size=100,
    device_count=-1,
    store_on_device=False,
)
K_batch = kernel_batch_fn(data, None, "ntk", init_params)

# Compute difference between two matrices. It should technically be 0.
print("Average error per entry:",  jnp.linalg.norm(K_full - K_batch) / K_full.size)

Surprisingly, if I run this with my old environment I get an average error on the order of 1e-8, while with the new environment the error is on the order of 1e-1. Moreover, the error stays exactly the same for any batch_size < data.shape[0].
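
A minimal sketch of a scale-independent comparison, reusing K_full and K_batch from the script above (the relative-error metric here is illustrative, not part of the original report):

# Relative Frobenius error between the two Gram matrices; independent of the
# overall scale of the kernel, unlike the per-entry figure above.
rel_err = jnp.linalg.norm(K_full - K_batch) / jnp.linalg.norm(K_full)
print("Relative Frobenius error:", rel_err)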

My old environment consisted of:

python=3.7.4
jax=0.2.8
jaxlib=0.1.57+cuda102
flax=0.3.0
neural-tangents=0.3.7

and my new environment has:

python=3.7.4
jax=0.2.19
jaxlib=0.1.70+cuda102
flax=0.3.4
neural-tangents=0.3.7
@sschoenholz

Thanks for reaching out! Unfortunately I wasn't able to reproduce the issue you're having. Here's a Colab notebook (note: I had to make the model slightly smaller to fit in public Colab GPU memory).

All of the versions seem similar to your current environment. However, I notice your GPU drivers seem a little out of date (CUDA 10.2 instead of 11.2); is it possible that is the issue?
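
One quick way to double-check which builds are in play (a minimal sketch using only the public jax/jaxlib version attributes; the pip-reported jaxlib wheel, e.g. +cuda102, indicates which CUDA toolkit it targets):

import jax
import jaxlib

# Report the Python-side library versions and the devices JAX can see,
# which shows whether the GPU backend is actually being used.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("devices:", jax.devices())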

@romanngg
Contributor

romanngg commented Sep 2, 2021

Btw, also consider trying implementation=1; we often find it faster than 2 for convolutions, especially if you only have one output logit (2 scales better with the number of logits, but this shouldn't matter if you only have one).
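
A minimal sketch of that variant, reusing the names from the repro script above; only the implementation argument changes:

# Same batched empirical NTK as before, but with implementation=1.
kernel_impl1_fn = nt.batch(
    nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=1, trace_axes=()),
    batch_size=100,
    device_count=-1,
    store_on_device=False,
)
K_impl1 = kernel_impl1_fn(data, None, "ntk", init_params)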

romanngg added the bug (Something isn't working) label on Sep 2, 2021
@gortizji
Author

gortizji commented Sep 2, 2021

Thanks for the quick answers! Curiously, it seems that using implementation=1 does make things more stable, i.e., an error of around 1e-8. Still, for small batches of around 10 the error climbs up to 1e-5, but that is definitely not 0.1.
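
A minimal sketch of such a sweep, reusing the names from the repro script (batch sizes are illustrative and chosen to divide the 500 inputs):

# Compare the batched kernels against K_full for several batch sizes.
for bs in (10, 50, 100, 250):
    kernel_fn = nt.batch(
        nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=1, trace_axes=()),
        batch_size=bs,
        device_count=-1,
        store_on_device=False,
    )
    K_bs = kernel_fn(data, None, "ntk", init_params)
    print("batch_size =", bs, "error:", jnp.linalg.norm(K_full - K_bs) / K_full.size)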

On the other hand, I also cannot reproduce this strange behaviour on Google Colab. I will try to update my CUDA version and get back to you.

@gortizji
Author

After updating to CUDA 11.4, I can confirm that the issue indeed only happens with the old CUDA version. With the new version, both implementation=1 and implementation=2 yield an error on the order of 1e-8, regardless of the batch_size.

In fact, it seems that support for CUDA 10.2 will be dropped in the next release of JAX. Even if that happens soon, I would still recommend directly specifying CUDA 11.x as a dependency of neural-tangents. I do not know what the root cause of this very strange behaviour was, but I am worried it might silently break other functionality of the library when it is used with the wrong CUDA version.

In any case, thank you very much for all your help!

romanngg added a commit that referenced this issue Nov 17, 2021
…ax.example_libraries`. Bump requirement to JAX v0.2.25 to avoid the CUDA-10 bug in #122 (thanks @gortizji for pointing this out!)

PiperOrigin-RevId: 409029167
@romanngg
Contributor

romanngg commented Nov 17, 2021

Thanks a lot for figuring this out! Just pushed a release (https://github.com/google/neural-tangents/releases/tag/v0.3.9) bumping our minimum JAX version to 0.2.25, which itself should only work with CUDA 11 and higher, so hopefully this is fixed! Please feel free to re-open if the issue remains.
