Refactoring: a long-awaited refactor that splits the huge `stax.py` into subcomponents, among other things:

(1) Move implementations into `_src`, and import the public functions from these sources into the top-level modules. This makes our public API easier to manage, remember, and view. A downside is that it becomes less convenient to "hack" the library by reaching into its private functions or modules.
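
For illustration, a minimal sketch of what the new layout means for imports, based on the `neural_tangents/__init__.py` diff below (the call sites themselves are hypothetical):

```python
import neural_tangents as nt

# Public API: these names are re-exported in `neural_tangents/__init__.py`
# (see the diff of that file below).
ntk_fn = nt.empirical_ntk_fn   # was neural_tangents.utils.empirical.empirical_ntk_fn
batched_fn = nt.batch          # was neural_tangents.utils.batch.batch

# Private implementations now live under `_src`; reaching into them still
# works, but is no longer part of the supported public surface.
from neural_tangents._src.empirical import empirical_ntk_fn
```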

(2) Split `stax` into 4 parts: `requirements`, `elementwise`, `linear`, and `combinators`. This makes the structure of `stax` easier to understand, and makes it easier to browse and to implement new layers or combinators, except for those that don't fall squarely into any one category.
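
Purely as an illustrative sketch, a small model touching the new categories; the user-facing `stax` API is unchanged by the split:

```python
from neural_tangents import stax

# `serial` is a combinator, `Dense` is a linear layer, `Relu` is elementwise;
# the new `requirements` part holds the shared kernel bookkeeping used internally.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Relu(),
    stax.Dense(1),
)
```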

(5) Move `test_utils` out of the library and into the tests folder only. Decouple more tests / examples / outside users from library internals.

(6) Rename `batch` to `batching` to avoid confusing the module with the function.

(7) Remove dependence on `jax.lib.xla_bridge`.
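
A hedged sketch of what dropping `jax.lib.xla_bridge` typically amounts to, assuming it was used for device discovery (this diff does not show the replacement explicitly):

```python
import jax

# Public JAX API instead of jax.lib.xla_bridge:
n_devices = jax.device_count()         # e.g. replaces xla_bridge.device_count()
n_local = jax.local_device_count()
backend = jax.default_backend()        # 'cpu', 'gpu', or 'tpu'
```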

PiperOrigin-RevId: 429185591
romanngg committed Feb 17, 2022
1 parent b9c2c57 commit 4b183ce
Showing 45 changed files with 13,257 additions and 12,684 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -199,4 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
4 changes: 3 additions & 1 deletion README.md
@@ -235,7 +235,7 @@ The `neural_tangents` (`nt`) package contains the following modules and function

* `monte_carlo_kernel_fn` - compute a Monte Carlo kernel estimate of _any_ `(init_fn, apply_fn)`, not necessarily specified via `nt.stax`, enabling the kernel computation of infinite networks without closed-form expressions.

* Tools to investigate training dynamics of _wide but finite_ neural networks, like `linearize`, `taylor_expand`, `empirical_kernel_fn` and more. See [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks) for details.
* Tools to investigate training dynamics of _wide but finite_ neural networks, like `linearize`, `taylor_expand`, `empirical.kernel_fn` and more. See [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks) for details.


## Technical gotchas
@@ -311,10 +311,12 @@ import jax.random as random
import jax.numpy as np
import neural_tangents as nt


def apply_fn(params, x):
W, b = params
return np.dot(x, W) + b


W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)
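
The README's continuation is collapsed in this diff; below is a hedged sketch of how the finite-width tools mentioned above apply to this snippet, assuming the `apply_fn` and `params` defined here:

```python
f_lin = nt.linearize(apply_fn, params)

# At the expansion point the linearization agrees with the original network.
x = np.array([[0.3, -1.2]])
assert np.allclose(f_lin(params, x), apply_fn(params, x))
```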
6 changes: 3 additions & 3 deletions docs/index.rst
@@ -9,10 +9,10 @@ neural networks.
:caption: Topics:

neural_tangents.stax
neural_tangents.empirical
neural_tangents._src.empirical
neural_tangents.predict
neural_tangents.batching
neural_tangents.monte_carlo
neural_tangents._src.batching
neural_tangents._src.monte_carlo

Indices and tables
==================
2 changes: 1 addition & 1 deletion docs/neural_tangents.batching.rst
@@ -2,5 +2,5 @@ Batching
===========================

.. default-role:: code
.. automodule:: neural_tangents.utils.batch
.. automodule:: neural_tangents._src.batching
:members:
2 changes: 1 addition & 1 deletion docs/neural_tangents.empirical.rst
@@ -2,5 +2,5 @@ Empirical
===========================

.. default-role:: code
.. automodule:: neural_tangents.utils.empirical
.. automodule:: neural_tangents._src.empirical
:members:
2 changes: 1 addition & 1 deletion docs/neural_tangents.monte_carlo.rst
@@ -2,5 +2,5 @@ Monte Carlo Sampling
===========================

.. default-role:: code
.. automodule:: neural_tangents.utils.monte_carlo
.. automodule:: neural_tangents._src.monte_carlo
:members:
2 changes: 1 addition & 1 deletion examples/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2019 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 changes: 5 additions & 2 deletions examples/datasets.py
@@ -91,8 +91,11 @@ def minibatch(x_train, y_train, batch_size, train_epochs):

if end > x_train.shape[0]:
key, split = random.split(key)
permutation = random.shuffle(split,
np.arange(x_train.shape[0], dtype=np.int64))
permutation = random.permutation(
split,
np.arange(x_train.shape[0], dtype=np.int64),
independent=True
)
x_train = x_train[permutation]
y_train = y_train[permutation]
epoch += 1
2 changes: 1 addition & 1 deletion examples/weight_space.py
@@ -41,7 +41,7 @@


def main(unused_argv):
# Build data and .
# Load data and preprocess it.
print('Loading data.')
x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
permute_train=True)
20 changes: 10 additions & 10 deletions neural_tangents/__init__.py
@@ -16,15 +16,15 @@
"""Public Neural Tangents modules and functions."""


__version__ = '0.4.0'
__version__ = '0.5.0'


from neural_tangents import predict
from neural_tangents import stax
from neural_tangents.utils.batch import batch
from neural_tangents.utils.empirical import empirical_kernel_fn
from neural_tangents.utils.empirical import empirical_nngp_fn
from neural_tangents.utils.empirical import empirical_ntk_fn
from neural_tangents.utils.empirical import linearize
from neural_tangents.utils.empirical import taylor_expand
from neural_tangents.utils.monte_carlo import monte_carlo_kernel_fn
from . import predict
from . import stax
from ._src.batching import batch
from ._src.empirical import empirical_kernel_fn
from ._src.empirical import empirical_nngp_fn
from ._src.empirical import empirical_ntk_fn
from ._src.empirical import linearize
from ._src.empirical import taylor_expand
from ._src.monte_carlo import monte_carlo_kernel_fn
File renamed without changes.
20 changes: 10 additions & 10 deletions neural_tangents/utils/batch.py → neural_tangents/_src/batching.py
@@ -55,9 +55,9 @@
from jax.tree_util import tree_all
from jax.tree_util import tree_map
from jax.tree_util import tree_multimap, tree_flatten, tree_unflatten
from neural_tangents.utils.kernel import Kernel
from neural_tangents.utils import utils
from neural_tangents.utils.typing import KernelFn, NTTree
from .utils.kernel import Kernel
from .utils import utils
from .utils.typing import KernelFn, NTTree

import numpy as onp

@@ -79,15 +79,18 @@ def batch(kernel_fn: KernelFn,
`kernel_fn(x1, x2, *args, **kwargs)`. Here `x1` and `x2` are
`np.ndarray`s of shapes `(n1,) + input_shape` and `(n2,) + input_shape`.
The kernel function should return a `PyTree`.
batch_size:
specifies the size of each batch that gets processed per physical device.
Because we parallelize the computation over columns it should be the case
that `x1.shape[0]` is divisible by `device_count * batch_size` and
`x2.shape[0]` is divisible by `batch_size`.
device_count:
specifies the number of physical devices to be used. If
`device_count == -1` all devices are used. If `device_count == 0`, no
device parallelism is used (a single default device is used).
store_on_device:
specifies whether the output should be kept on device or brought back to
CPU RAM as it is computed. Defaults to `True`. Set to `False` to store
@@ -249,7 +252,6 @@ def _flatten_kernel(k: Kernel,
def _reshape_kernel_for_pmap(k: Kernel,
device_count: int,
n1_per_device: int) -> Kernel:
# pytype: disable=attribute-error
cov2 = k.cov2
if cov2 is None:
cov2 = k.cov1
@@ -283,7 +285,6 @@ def _set_cov2_to_none(
if isinstance(k, Kernel):
k = k.replace(cov2=None)
return k
# pytype: enable=attribute-error


def _serial(kernel_fn: KernelFn,
@@ -444,8 +445,7 @@ def col_fn(n1, n2):
in_kernel = slice_kernel(k, n1_slice, n2_slice)
return (n1, kwargs1), kernel_fn(in_kernel, *args, **kwargs_merge)

cov2_is_none = utils.nt_tree_fn(reduce=lambda k: all(k))(lambda k:
k.cov2 is None)(k)
cov2_is_none = utils.nt_tree_fn(reduce=all)(lambda k: k.cov2 is None)(k)
_, k = _scan(row_fn, 0, (n1s, kwargs_np1))
if cov2_is_none:
k = _set_cov2_to_none(k)
@@ -520,7 +520,7 @@ def _check_dropout(n1, n2, kwargs):
'Using `serial` (i.e. use a non-zero batch_size in the '
'`batch` function.) could enforce square batch size in each device.')

def _get_n_per_device(n1, n2):
def _get_n_per_device(n1):
_device_count = device_count

n1_per_device, ragged = divmod(n1, device_count)
@@ -549,7 +549,7 @@ def get_batch_size(x):
n2 = n1 if x2_is_none else get_batch_size(x2)

_check_dropout(n1, n2, kwargs)
n1_per_device, _device_count = _get_n_per_device(n1, n2)
n1_per_device, _device_count = _get_n_per_device(n1)

_kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

@@ -579,7 +579,7 @@ def get_batch_sizes(k):

n1, n2 = get_batch_sizes(kernel)
_check_dropout(n1, n2, kwargs)
n1_per_device, _device_count = _get_n_per_device(n1, n2)
n1_per_device, _device_count = _get_n_per_device(n1)

_kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

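To ground the `batch` arguments documented above, a minimal usage sketch (the `stax` network is hypothetical; the shapes follow the docstring's divisibility constraints):

```python
from jax import random
import neural_tangents as nt
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(1024), stax.Relu(), stax.Dense(1))

# With device_count=1 and batch_size=10: x1.shape[0] == 40 is divisible by
# device_count * batch_size, and x2.shape[0] == 20 is divisible by batch_size.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=10, device_count=1,
                             store_on_device=False)

x1 = random.normal(random.PRNGKey(1), (40, 8))
x2 = random.normal(random.PRNGKey(2), (20, 8))
k_ntk = batched_kernel_fn(x1, x2, 'ntk')  # computed in 10x10 blocks
```
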
neural_tangents/utils/empirical.py → neural_tangents/_src/empirical.py
@@ -17,11 +17,11 @@
All functions in this module are applicable to any JAX functions of proper
signatures (not only those from `nt.stax`).
NNGP and NTK are computed using `empirical_nngp_fn`, `empirical_ntk_fn`, or
`empirical_kernel_fn` (for both). The kernels have a very specific output shape
convention that may be unexpected. Further, NTK has multiple implementations
that may perform differently depending on the task. Please read individual
functions' docstrings.
NNGP and NTK are computed using `empirical_nngp_fn`, `nt.empirical_ntk_fn`, or
`nt.empirical_kernel_fn` (for both). The kernels have a very specific output
shape convention that may be unexpected. Further, NTK has multiple
implementations that may perform differently depending on the task. Please read
individual functions' docstrings.
Example:
>>> from jax import random
@@ -49,18 +49,18 @@
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=1` since the network has few trainable parameters.
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(-1,),
>>> vmap_axes=0, implementation=1)
>>> kernel_fn = nt.empirical_kernel_fn(
>>> f, trace_axes=(-1,), vmap_axes=0, implementation=1)
>>>
>>> # (5, 20) np.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>> nngp_test_train = empirical_kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = empirical_kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>> # Full kernel: not reducing over logits.
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>> k_test_train = empirical_kernel_fn(x_test, x_train, params)
>>>
>>> # A wide FCN with lots of parameters
>>> init_fn, f, _ = stax.serial(
@@ -79,22 +79,22 @@
>>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>> # (5, 5) np.ndarray test-test NTK
>>> ntk_test_train = ntk_fn(x_test, None, params)
>>> ntk_test_train = empirical_ntk_fn(x_test, None, params)
>>>
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>> # (20,) np.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
>>> nngp_train_train_diag = empirical_nngp_fn(x_train, None, params)
"""

import operator
from typing import Union, Callable, Optional, Tuple, Dict
from jax import eval_shape, jacobian, jvp, vjp, vmap, linear_transpose
import jax.numpy as np
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_map
from neural_tangents.utils import utils
from neural_tangents.utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes, VMapAxes, VMapAxisTriple
from .utils import utils
from .utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes, VMapAxes, VMapAxisTriple


def linearize(f: Callable[..., PyTree],
@@ -589,22 +589,22 @@ def empirical_ntk_fn(f: ApplyFn,
vmap_axes=vmap_axes)

if implementation == 1:
return _empirical_direct_ntk_fn(**kwargs)
return _direct_ntk_fn(**kwargs)

if implementation == 2:
return _empirical_implicit_ntk_fn(**kwargs)
return _implicit_ntk_fn(**kwargs)

raise ValueError(implementation)


def _empirical_implicit_ntk_fn(f: ApplyFn,
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None
) -> Callable[[NTTree[np.ndarray],
Optional[NTTree[np.ndarray]],
PyTree],
NTTree[np.ndarray]]:
def _implicit_ntk_fn(f: ApplyFn,
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None
) -> Callable[[NTTree[np.ndarray],
Optional[NTTree[np.ndarray]],
PyTree],
NTTree[np.ndarray]]:
"""Compute NTK implicitly without instantiating full Jacobians."""

def ntk_fn(x1: NTTree[np.ndarray],
@@ -688,14 +688,14 @@ def delta_vjp(delta):
return ntk_fn


def _empirical_direct_ntk_fn(f: ApplyFn,
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None
) -> Callable[[NTTree[np.ndarray],
Optional[NTTree[np.ndarray]],
PyTree],
NTTree[np.ndarray]]:
def _direct_ntk_fn(f: ApplyFn,
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None
) -> Callable[[NTTree[np.ndarray],
Optional[NTTree[np.ndarray]],
PyTree],
NTTree[np.ndarray]]:
"""Compute NTK by directly instantiating Jacobians and contracting."""

@utils.nt_tree_fn(tree_structure_argnum=0)