Skip to content

Commit

Permalink
Add an NTK-vector product function (without instantiating the NTK).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 473147728
  • Loading branch information
romanngg committed Sep 9, 2022
1 parent d289069 commit 628ce0e
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 29 deletions.
9 changes: 9 additions & 0 deletions docs/empirical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ An :class:`enum.IntEnum` specifying NTK implementation method.

.. autoclass:: NtkImplementation

NTK-vector products
--------------------------------------
A function to compute NTK-vector products without instantiating the NTK.

.. autosummary::
:toctree: _autosummary

empirical_ntk_vp_fn

Linearization and Taylor expansion
--------------------------------------
Decorators to Taylor-expand around function parameters.
Expand Down
3 changes: 2 additions & 1 deletion neural_tangents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Public Neural Tangents modules and functions."""


__version__ = '0.6.0'
__version__ = '0.6.1'

from . import experimental
from . import predict
Expand All @@ -25,6 +25,7 @@
from ._src.empirical import empirical_kernel_fn
from ._src.empirical import empirical_nngp_fn
from ._src.empirical import empirical_ntk_fn
from ._src.empirical import empirical_ntk_vp_fn
from ._src.empirical import linearize
from ._src.empirical import NtkImplementation
from ._src.empirical import taylor_expand
Expand Down
104 changes: 104 additions & 0 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,110 @@ def kernel_fn(
return kernel_fn


# NTK-VECTOR PRODUCT FUNCTION


def empirical_ntk_vp_fn(
f: ApplyFn,
x1: PyTree,
x2: Optional[PyTree],
params: PyTree,
**apply_fn_kwargs
) -> Callable[[PyTree], PyTree]:
"""Returns an NTK-vector product function.
The function computes NTK-vector product without instantiating the NTK, and
has the runtime equivalent to `(N1 + N2)` forward passes through `f`, and
memory equivalent to evaluating a vector-Jacobian product of `f`.
For details, please see section L of "`Fast Finite Width Neural Tangent Kernel
<https://arxiv.org/abs/2206.08720>`_".
Example:
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
>>> x1 = random.normal(k1, (20, 32, 32, 3))
>>> x2 = random.normal(k2, (10, 32, 32, 3))
>>> #
>>> # Define a forward-pass function `f`.
>>> init_fn, f, _ = stax.serial(
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>> #
>>> # Initialize parameters.
>>> _, params = init_fn(k3, x1.shape)
>>> #
>>> # NTK-vp function. Can/should be JITted.
>>> ntk_vp_fn = empirical_ntk_vp_fn(f, x1, x2, params)
>>> #
>>> # Cotangent vector
>>> cotangents = random.normal(k4, f(params, x2).shape)
>>> #
>>> # NTK-vp output
>>> ntk_vp = ntk_vp_fn(cotangents)
>>> #
>>> # Output has same shape as `f(params, x1)`.
>>> assert ntk_vp.shape == f(params, x1).shape
Args:
f:
forward-pass function of signature `f(params, x)`.
x1:
first batch of inputs.
x2:
second batch of inputs. `x2=None` means `x2=x1`.
params:
A `PyTree` of parameters about which we would like to compute the
neural tangent kernel.
**apply_fn_kwargs:
keyword arguments passed to `f`. `apply_fn_kwargs` will be split into
`apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function
which will be passed to `f`. In particular, the rng key in
`apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or same
(if `x1==x2`) rng keys. See the `_read_key` function for more details.
Returns:
An NTK-vector product function accepting a `PyTree` of cotangents of shape
and structure of `f(params, x2)`, and returning the NTK-vector product of
shape and structure of `f(params, x1)`.
"""
args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args(
f, apply_fn_kwargs, params, None, x1, x2)

f1, f2 = _get_f1_f2(f, keys, x_axis, fx_axis, kw_axes, args1 + args2, x1, x2)

def ntk_vp_fn(cotangents: PyTree) -> PyTree:
"""Computes a single empirical NTK-vector product.
Args:
cotangents:
a `PyTree` of cotangents. Must have the same shape and tree structure
as `f(params, x2)`.
Returns:
A single NTK-vector product of shape and tree structure of
`f(params, x1)`.
"""
vjp_out = vjp(f2, params)[1](cotangents)
jvp_out = jvp(f1, (params,), vjp_out)[1]
return jvp_out

return ntk_vp_fn


# INTERNAL UTILITIES


Expand Down
109 changes: 81 additions & 28 deletions tests/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,34 +1332,34 @@ def _get_mixer_b16_config() -> Dict[str, Any]:


@test_utils.product(
j_rules=[
True,
False
],
s_rules=[
True,
# False
],
fwd=[
True,
False,
None,
],
same_inputs=[
# True,
False
],
do_jit=[
True,
# False
],
do_remat=[
# True,
False
],
dtype=[
jax.dtypes.canonicalize_dtype(np.float64),
]
j_rules=[
True,
False
],
s_rules=[
True,
# False
],
fwd=[
True,
False,
None,
],
same_inputs=[
# True,
False
],
do_jit=[
True,
# False
],
do_remat=[
# True,
False
],
dtype=[
jax.dtypes.canonicalize_dtype(np.float64),
]
)
class FlaxOtherTest(test_utils.NeuralTangentsTestCase):

Expand Down Expand Up @@ -1629,5 +1629,58 @@ def f(p, x):
vmap_axes=vmap_axes)


class EmpiricalNtkVpTest(test_utils.NeuralTangentsTestCase):

@test_utils.product(
same_inputs=[
True,
False
],
do_jit=[
True,
False
],
)
def test_ntk_vp_fn(
self,
same_inputs,
do_jit,
):
N1 = 4
N2 = N1 if same_inputs else 6
O = 3

init_fn, f, _ = stax.serial(
stax.Dense(8),
stax.Relu(),
stax.Dense(O)
)

k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
x1 = random.normal(k1, (N1, 7))
x2 = None if same_inputs else random.normal(k2, (N2, 7))
_, params = init_fn(k3, x1.shape)

ntk_ref = nt.empirical_ntk_fn(f, (), vmap_axes=0)(x1, x2, params)
ntk_ref = np.moveaxis(ntk_ref, 1, 2)

# Compute an NTK via NTK-vps and compare to the reference
ntk_vp_fn = nt.empirical_ntk_vp_fn(f, x1, x2, params)
if do_jit:
ntk_vp_fn = jit(ntk_vp_fn)

eye = np.eye(N2 * O).reshape((N2 * O, N2, O))
ntk_vps = jit(jax.vmap(ntk_vp_fn))(eye)
ntk_vps = np.moveaxis(ntk_vps, (0,), (2,))
ntk_vps = ntk_vps.reshape((N1, O, N2, O))
self.assertAllClose(ntk_ref, ntk_vps)

# Compute a single NTK-vp via reference NTK, and compare to the NTK-vp.
cotangents = random.normal(k4, f(params, x1 if same_inputs else x2).shape)
ntk_vp_ref = np.tensordot(ntk_ref, cotangents, ((2, 3), (0, 1)))
ntk_vp = ntk_vp_fn(cotangents)
self.assertAllClose(ntk_vp_ref, ntk_vp)


if __name__ == '__main__':
absltest.main()

0 comments on commit 628ce0e

Please sign in to comment.