Minor cleanup/refactoring:
- Remove some unnecessary cross-dependencies between `utils`/`test_utils`/`Kernel`.
- Fix a minor bug in batching over tree kernels when `x2 is None`, and enable the respective test case that was accidentally omitted.
- Refactor some duplicate code.
- Fix typos.

PiperOrigin-RevId: 426949291
romanngg committed Feb 8, 2022
1 parent 7f75cae commit ae01c9f
Showing 11 changed files with 1,674 additions and 1,658 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -89,7 +89,7 @@ set -e; for f in tests/*.py; do python $f; done

<b>See this [Colab](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/neural_tangents_cookbook.ipynb) for a detailed tutorial. Below is a very quick introduction.</b>

Our library closely follows JAX's API for specifying neural networks, [`stax`](https://github.com/google/jax/blob/main/jax/experimental/stax.py). In `stax` a network is defined by a pair of functions `(init_fn, apply_fn)` initializing the trainable parameters and computing the outputs of the network respectively. Below is an example of defining a 3-layer network and computing it's outputs `y` given inputs `x`.
Our library closely follows JAX's API for specifying neural networks, [`stax`](https://github.com/google/jax/blob/main/jax/experimental/stax.py). In `stax` a network is defined by a pair of functions `(init_fn, apply_fn)` initializing the trainable parameters and computing the outputs of the network respectively. Below is an example of defining a 3-layer network and computing its outputs `y` given inputs `x`.

```python
from jax import random
@@ -127,7 +127,7 @@ x2 = random.normal(key2, (20, 100))
kernel = kernel_fn(x1, x2, 'nngp')
```

Note that `kernel_fn` can compute _two_ covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the _Bayesian_ infinite neural network [[1-5]](#5-deep-neural-networks-as-gaussian-processes). The NTK corresponds to the _(continuous) gradient descent trained_ infinite network [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks). In the above example, we compute the NNGP kernel but we could compute the NTK or both:
Note that `kernel_fn` can compute _two_ covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the _Bayesian_ infinite neural network [[1-5]](#5-deep-neural-networks-as-gaussian-processes). The NTK corresponds to the _(continuous) gradient descent trained_ infinite network [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks). In the above example, we compute the NNGP kernel, but we could compute the NTK or both:

```python
# Get kernel of a single type
@@ -270,7 +270,7 @@ For this, we provide two convenient functions:
* `nt.linearize`, and
* `nt.taylor_expand`,

which allow to linearize or get an arbitrary-order Taylor expansion of any function `apply_fn(params, x)` around some initial parameters `params_0` as `apply_fn_lin = nt.linearize(apply_fn, params_0)`.
which allow us to linearize or get an arbitrary-order Taylor expansion of any function `apply_fn(params, x)` around some initial parameters `params_0` as `apply_fn_lin = nt.linearize(apply_fn, params_0)`.

One can use `apply_fn_lin(params, x)` exactly as you would any other function
(including as an input to JAX optimizers). This makes it easy to compare the
@@ -348,7 +348,7 @@ dependent. However, some rules of thumb that we've observed are:
agreement by the time the layer-width is 512 (RMSE of about 0.05 at the
end of training).

* For convolutional networks one generally observes reasonable agreement
* For convolutional networks one generally observes reasonable
agreement by the time the number of channels is 512.

2. Convergence at small learning rates.
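As a quick illustration of the `nt.linearize` hunk above, a minimal usage sketch (the network, widths, and input shapes are illustrative and not taken from this diff):

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))

key = random.PRNGKey(0)
x = random.normal(key, (3, 100))
_, params_0 = init_fn(key, x.shape)  # initial parameters to linearize around

apply_fn_lin = nt.linearize(apply_fn, params_0)

# At `params_0` the linearization agrees with the original network; away from
# `params_0` it is the first-order Taylor expansion in the parameters.
y = apply_fn(params_0, x)
y_lin = apply_fn_lin(params_0, x)
```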
2 changes: 1 addition & 1 deletion neural_tangents/LICENSE_SHORT
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
18 changes: 9 additions & 9 deletions neural_tangents/stax.py
@@ -994,7 +994,7 @@ def agg(k, diagonal_batch, s1, r1, s2, r2):
k = k.replace(nngp=nngp, ntk=ntk, cov1=cov1, cov2=cov2)

else:
raise ValueError(f'Unregocnized `implementation == {implementation}.')
raise ValueError(f'Unrecognized `implementation == {implementation}.')

return k

@@ -1632,7 +1632,7 @@ def FanInSum() -> InternalLayer:
init_fn, apply_fn = ostax.FanInSum

def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
ks, is_reversed = _preprocess_kernels_for_fan_in(ks)
if not all([k.shape1 == ks[0].shape1 and
k.shape2 == ks[0].shape2 for k in ks[1:]]):
raise ValueError('All shapes should be equal in `FanInSum/FanInProd`, '
@@ -1686,7 +1686,7 @@ def FanInProd() -> InternalLayer:
"""Layer construction function for a fan-in product layer.
This layer takes a number of inputs (e.g. produced by `FanOut`) and
elementwisely multiply the inputs to produce a single output.
elementwise multiplies the inputs to produce a single output.
Returns:
`(init_fn, apply_fn, kernel_fn)`.
@@ -1697,7 +1697,7 @@ def apply_fn(params, inputs, **kwargs):
return functools.reduce(np.multiply, inputs)

def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
ks, is_reversed = _preprocess_kernels_for_fan_in(ks)
if not all([k.shape1 == ks[0].shape1 and
k.shape2 == ks[0].shape2 for k in ks[1:]]):
raise ValueError('All shapes should be equal in `FanInProd`.')
@@ -1762,7 +1762,7 @@ def FanInConcat(axis: int = -1) -> InternalLayer:
init_fn, apply_fn = ostax.FanInConcat(axis)

def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
ks, is_reversed = _preprocess_kernels_for_fan_in(ks)

diagonal_batch = ks[0].diagonal_batch
diagonal_spatial = ks[0].diagonal_spatial
@@ -3901,7 +3901,7 @@ def apply_fn(params, x, **kwargs):
precision=precision)

def mask_fn(mask, input_shape):
# Interploation (except for "NEAREST") is done in float format:
# Interpolation (except for "NEAREST") is done in float format:
# https://github.com/google/jax/issues/3811. Float converted back to bool
# rounds up all non-zero elements to `True`, so naively resizing the `mask`
# will mark any output that has at least one contribution from a masked
@@ -3920,7 +3920,7 @@ def mask_fn(mask, input_shape):
# >>> DeviceArray([[ True, True],
# >>> [ True, True]], dtype=bool)
#
# Therefore, througout `stax` we rather follow the convention of marking
# Therefore, throughout `stax` we rather follow the convention of marking
# outputs as masked if they _only_ have contributions from masked elements
# (in other words, we don't let the mask destroy information; let content
# have preference over mask). For this we invert the mask before and after
@@ -4271,7 +4271,7 @@ def _cov(
channel_axis: Specifies which axis is the channel / feature axis.
For `kernel_fn`, channel size is considered to be infinite.
Returns:
Matrix of uncentred batch covariances with shape
Matrix of uncentered batch covariances with shape
`(batch_size_1, batch_size_2, <S spatial dimensions>)`
if `diagonal_spatial` is `True`, or
`(batch_size_1, batch_size_2, <2*S spatial dimensions>)`
@@ -4789,7 +4789,7 @@ def _affine(
return mat


def _proprocess_kernels_for_fan_in(ks: Kernels) -> Tuple[List[Kernel], bool]:
def _preprocess_kernels_for_fan_in(ks: Kernels) -> Tuple[List[Kernel], bool]:
# Check diagonal requirements.
if not all(k.diagonal_batch == ks[0].diagonal_batch and
k.diagonal_spatial == ks[0].diagonal_spatial and
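For context on the fan-in layers whose helper is renamed above to `_preprocess_kernels_for_fan_in`, a small usage sketch in the style of the README's residual blocks (the widths and the exact composition are illustrative, not taken from this diff):

```python
from neural_tangents import stax

# `FanOut` duplicates the activations into two branches; `FanInSum` adds them
# back together. `FanInProd` and `FanInConcat` follow the same fan-in pattern.
ResidualBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(stax.Relu(), stax.Dense(512)),  # main branch
        stax.Identity(),                            # shortcut branch
    ),
    stax.FanInSum(),
)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), ResidualBlock, stax.Dense(1))
```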
11 changes: 7 additions & 4 deletions neural_tangents/utils/batch.py
@@ -277,8 +277,11 @@ def _reshape_kernel_for_pmap(k: Kernel,


@utils.nt_tree_fn()
def _set_cov2_is_none(k: Kernel) -> Kernel:
return k.replace(cov2=None)
def _set_cov2_to_none(
k: Union[Kernel, np.ndarray]) -> Union[Kernel, np.ndarray]:
if isinstance(k, Kernel):
k = k.replace(cov2=None)
return k
# pytype: enable=attribute-error


@@ -444,7 +447,7 @@ def col_fn(n1, n2):
k.cov2 is None)(k)
_, k = _scan(row_fn, 0, (n1s, kwargs_np1))
if cov2_is_none:
k = _set_cov2_is_none(k)
k = _set_cov2_to_none(k)
return flatten(k, cov2_is_none)

@utils.wraps(kernel_fn)
@@ -584,7 +587,7 @@ def get_batch_sizes(k):
kernel = _reshape_kernel_for_pmap(kernel, _device_count, n1_per_device)
kernel = _kernel_fn(kernel, *args, **kwargs)
if cov2_is_none:
kernel = _set_cov2_is_none(kernel)
kernel = _set_cov2_to_none(kernel)
return _flatten_kernel(kernel, cov2_is_none, True)

@utils.wraps(kernel_fn)
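To connect the `_set_cov2_to_none` change above to a user-facing call, a hedged sketch of the `x2 is None` batching path (it assumes the public `nt.batch` API; the nested "tree" kernel outputs mentioned in the commit message are handled the same way, leaf by leaf):

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Compute the kernel in 10x10 blocks instead of all 40x40 entries at once.
kernel_fn_batched = nt.batch(kernel_fn, batch_size=10)

x1 = random.normal(random.PRNGKey(1), (40, 100))

# `x2 is None` is the symmetric case in which `cov2` must end up as `None`.
# When a concrete kernel type such as 'nngp' is requested, the leaves seen by
# the batching helpers can be plain arrays rather than `Kernel` dataclasses,
# hence the new `isinstance` check in `_set_cov2_to_none`.
nngp = kernel_fn_batched(x1, None, 'nngp')  # shape (40, 40)
```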
19 changes: 17 additions & 2 deletions neural_tangents/utils/kernel.py
@@ -32,21 +32,26 @@ class Kernel:
nngp:
covariance between the first and second batches (NNGP). A `np.ndarray` of
shape
`(batch_size_1, batch_size_2, height, [height,], width, [width,], ..))`,
`(batch_size_1, batch_size_2, height, [height,], width, [width,], ...))`,
where exact shape depends on `diagonal_spatial`.
ntk:
the neural tangent kernel (NTK). `np.ndarray` of same shape as `nngp`.
cov1:
covariance of the first batch of inputs. A `np.ndarray` with shape
`(batch_size_1, [batch_size_1,] height, [height,], width, [width,], ..)`
`(batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...)`
where exact shape depends on `diagonal_batch` and `diagonal_spatial`.
cov2:
optional covariance of the second batch of inputs. A `np.ndarray` with
shape
`(batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...)`
where the exact shape depends on `diagonal_batch` and `diagonal_spatial`.
x1_is_x2:
a boolean specifying whether `x1` and `x2` are the same.
is_gaussian:
a boolean, specifying whether the output features or channels of the layer
/ NN function (returning this `Kernel` as the `kernel_fn`) are i.i.d.
@@ -55,22 +60,26 @@
an input through a CNN layer with i.i.d. Gaussian weights and biases
produces i.i.d. Gaussian random variables along the channel dimension,
while passing an input through a nonlinearity does not.
is_reversed:
a boolean specifying whether the covariance matrices `nngp`, `cov1`,
`cov2`, and `ntk` have the ordering of spatial dimensions reversed.
Ignored unless `diagonal_spatial` is `False`. Used internally to avoid
self-cancelling transpositions in a sequence of CNN layers that flip the
order of kernel spatial dimensions.
is_input:
a boolean specifying whether the current layer is the input layer and it
is used to avoid applying dropout to the input layer.
diagonal_batch:
a boolean specifying whether `cov1` and `cov2` store only the diagonal of
the sample-sample covariance (`diagonal_batch == True`,
`cov1.shape == (batch_size_1, ...)`), or the full covariance
(`diagonal_batch == False`,
`cov1.shape == (batch_size_1, batch_size_1, ...)`). Defaults to `True` as
no current layers require the full covariance.
diagonal_spatial:
a boolean specifying whether all (`cov1`, `ntk`, etc.) covariance matrices
store only the diagonals of the location-location covariances
@@ -81,18 +90,23 @@
depth, depth, ...)`). Defaults to `False`, but is set to `True` if the
output top-layer covariance depends only on the diagonals (e.g. when a CNN
network has no pooling layers and `Flatten` on top).
shape1:
a tuple specifying the shape of the random variable in the first batch of
inputs. These have covariance `cov1` and covariance with the second batch
of inputs given by `nngp`.
shape2:
a tuple specifying the shape of the random variable in the second batch of
inputs. These have covariance `cov2` and covariance with the first batch
of inputs given by `nngp`.
batch_axis:
the batch axis of the activations.
channel_axis:
channel axis of the activations (taken to infinity).
mask1:
an optional boolean `np.ndarray` with a shape broadcastable to `shape1`
(and the same number of dimensions). `True` stands for the input being
@@ -101,6 +115,7 @@
images), a `mask1` of shape `(5, 1, 32, 1)` means different images can
have different blocked columns (`H` and `C` dimensions are always either
both blocked or unblocked). `None` means no masking.
mask2:
same as `mask1`, but for the second batch of inputs.
"""
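Since the attributes documented above are easier to picture on a concrete kernel, a short inspection sketch (shapes are illustrative; it assumes the convention that calling `kernel_fn` without a `get` argument returns the full `Kernel` dataclass):

```python
from jax import random
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

x1 = random.normal(random.PRNGKey(1), (3, 5))
x2 = random.normal(random.PRNGKey(2), (4, 5))

k = kernel_fn(x1, x2)             # full `Kernel` dataclass
print(k.nngp.shape)               # (3, 4): batch_size_1 x batch_size_2
print(k.cov1.shape)               # (3,) under the default `diagonal_batch=True`
print(k.shape1, k.shape2)         # (3, 5) (4, 5): shapes of the two input batches
print(k.diagonal_batch, k.diagonal_spatial, k.is_gaussian)
```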
19 changes: 8 additions & 11 deletions neural_tangents/utils/test_utils.py
@@ -24,8 +24,6 @@
from jax.lib import xla_bridge
import jax.numpy as np
import jax.test_util as jtu
from .kernel import Kernel
from neural_tangents.utils import utils
import numpy as onp


@@ -78,7 +76,6 @@ def _log(relative_error, absolute_error, expected, actual, did_pass):


def assert_close_matrices(self, expected, actual, rtol, atol=0.1):
@utils.nt_tree_fn()
def assert_close(expected, actual):
self.assertEqual(expected.shape, actual.shape)
relative_error = (
@@ -104,7 +101,7 @@ def assert_close(expected, actual):
else:
_log(relative_error, absolute_error, expected, actual, True)

assert_close(expected, actual)
jax.tree_map(assert_close, expected, actual)


class NeuralTangentsTestCase(jtu.JaxTestCase):
@@ -124,15 +121,15 @@ def assert_close(x, y):
x, y, check_dtypes=check_dtypes, atol=atol, rtol=rtol,
canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg)

if isinstance(x, Kernel):
self.assertIsInstance(y, Kernel)
for field in dataclasses.fields(Kernel):
name = field.name
x_name, y_name = getattr(x, name), getattr(y, name)
if dataclasses.is_dataclass(x):
self.assertIs(type(y), type(x))
for field in dataclasses.fields(x):
key = field.name
x_value, y_value = getattr(x, key), getattr(y, key)
is_pytree_node = field.metadata.get('pytree_node', True)
if is_pytree_node:
assert_close(x_name, y_name)
assert_close(x_value, y_value)
else:
self.assertEqual(x_name, y_name, name)
self.assertEqual(x_value, y_value, key)
else:
assert_close(x, y)
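The switch above from the NT-specific `utils.nt_tree_fn` decorator to `jax.tree_map`, and from a hard-coded `Kernel` check to generic `dataclasses.fields`, is what decouples `test_utils` from the library internals. A toy sketch of the `tree_map` half of that pattern (the data is made up):

```python
import jax
import jax.numpy as np
import numpy as onp

# Leaf-wise comparison of two arbitrary pytrees (dicts/lists/tuples of arrays);
# no NT-specific tree utilities are required.
expected = {'nngp': np.ones((2, 2)), 'ntk': np.zeros((2, 2))}
actual = {'nngp': np.ones((2, 2)), 'ntk': np.zeros((2, 2))}

jax.tree_map(lambda e, a: onp.testing.assert_allclose(e, a), expected, actual)
```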
3 changes: 1 addition & 2 deletions neural_tangents/utils/utils.py
@@ -29,7 +29,6 @@
from jax.lib import xla_bridge
import jax.numpy as np
from jax.tree_util import tree_all, tree_map
from .kernel import Kernel
import numpy as onp


@@ -157,7 +156,7 @@ def _output_to_dict(output):
if isinstance(output, dict):
return output

if isinstance(output, Kernel):
if hasattr(output, 'asdict'):
return output.asdict()

if hasattr(output, '_asdict'):
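The change above from `isinstance(output, Kernel)` to `hasattr(output, 'asdict')` is the duck-typing that lets `utils.py` drop its `Kernel` import. A toy illustration (the class and return values are made up, and the function is a simplified restatement of `_output_to_dict`):

```python
# Any object exposing `asdict()` is accepted without importing its class.
class FakeKernel:
  def asdict(self):
    return {'nngp': 1.0, 'ntk': 2.0}

def output_to_dict(output):
  if isinstance(output, dict):
    return output
  if hasattr(output, 'asdict'):    # e.g. the `Kernel` dataclass
    return output.asdict()
  if hasattr(output, '_asdict'):   # e.g. namedtuples
    return output._asdict()
  return output

assert output_to_dict(FakeKernel()) == {'nngp': 1.0, 'ntk': 2.0}
```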