Add a stax.Elementwise layer that allows simpler layer specification by passing only a scalar-valued `nngp_fn`, computing `ntk_fn` via autodiff (derived by Lechao Xiao xlc@google.com @SiuMath).

Apart from automatic NTK, this also makes `nngp_fn` more accessible to users, who won't need to worry about our kernel layout and other technicalities, and need only specify a scalar-valued function.

PiperOrigin-RevId: 397428785
romanngg committed Sep 18, 2021
1 parent 204190c commit 25788a9
Showing 2 changed files with 253 additions and 0 deletions.
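In practice the new layer composes like any other `stax` layer. A minimal usage sketch (the inputs `x1`, `x2` and all shapes below are illustrative, not part of the commit; see the docstring in the diff for the full contract):

import jax
import jax.numpy as np
from jax import random
from neural_tangents import stax

fn = jax.scipy.special.erf  # scalar activation

def nngp_fn(cov12, var1, var2):
  # E[erf(x1) * erf(x2)] over bivariate normals -- see the docstring below.
  prod = (1 + 2 * var1) * (1 + 2 * var2)
  return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

# The NTK is derived from `nngp_fn` by autodiff; no `d_nngp_fn` needed.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1), stax.Elementwise(fn, nngp_fn))

x1 = random.normal(random.PRNGKey(1), (5, 3))
x2 = random.normal(random.PRNGKey(2), (6, 3))
k = kernel_fn(x1, x2)  # `k.nngp` and `k.ntk` are both available.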
169 changes: 169 additions & 0 deletions neural_tangents/stax.py
@@ -3354,6 +3354,110 @@ def kernel_fn(k: Kernel) -> Kernel:
  return _elementwise(fn, 'Sign', kernel_fn)


@layer
@supports_masking(remask_kernel=True)
def Elementwise(
    fn: Optional[Callable[[float], float]] = None,
    nngp_fn: Optional[Callable[[float, float, float], float]] = None,
    d_nngp_fn: Optional[Callable[[float, float, float], float]] = None
) -> InternalLayer:
"""Elementwise application of `fn` using provided `nngp_fn`.
Constructs a layer given only scalar-valued nonlinearity / activation
`fn` and the 2D integral `nngp_fn`. NTK function is derived automatically in
closed form from `nngp_fn`.
If you cannot provide the `nngp_fn`, see `nt.stax.ElementwiseNumerical` to use
numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use Monte
Carlo sampling.
If your function is implemented separately (e.g. `nt.stax.Relu` etc) it's best
to use the custom implementation, since it uses symbolically simplified
expressions that are more precise and numerically stable.
Example:
>>> fn = jax.scipy.special.erf # type: Callable[[float], float]
>>>
>>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
>>> prod = (1 + 2 * var1) * (1 + 2 * var2)
>>> return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
>>>
>>> # Use autodiff and vectorization to construct the layer:
>>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
>>>
>>> # Use custom pre-derived expressions
>>> # (should be faster and more numerically stable):
>>> _, _, kernel_fn_stax = stax.Erf()
>>>
>>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2) # usually `True`.
Args:
fn:
a scalar-input/valued function `fn : R -> R`, the activation /
nonlinearity. If `None`, invoking the finite width `apply_fn` will raise
an exception.
nngp_fn:
a scalar-valued function
`nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`, where the
expectation is over bivariate normal `x1, x2` with variances `var1`,
`var2` and covarianve `cov12`. Needed for both NNGP and NTK calculation.
If `None`, invoking infinite width `kernel_fn` will raise an exception.
d_nngp_fn:
an optional scalar-valued function
`d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)]` with the same
`x1, x2` distribution as in `nngp_fn`. If `None`, will be computed using
automatic differentiation as `d_nngp_fn = d(nngp_fn)/d(cov12)`, which may
lead to worse precision or numerical stability. `nngp_fn` and `d_nngp_fn`
are used to derive the closed-form expression for the NTK.
Returns:
`(init_fn, apply_fn, kernel_fn)`.
Raises:
NotImplementedError: if a `fn`/`nngp_fn` is not provided, but `apply_fn`/
`kernel_fn` is called respectively.
"""
  if fn is not None:
    name = fn.__name__
  elif nngp_fn is not None:
    name = nngp_fn.__name__
  else:
    raise ValueError('No finite (`fn`) or infinite (`nngp_fn`) functions '
                     'provided, the layer will not do anything.')

  if nngp_fn is None:
    kernel_fn = None

  else:
    if d_nngp_fn is None:
      warnings.warn(
          'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
          'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.')
      d_nngp_fn = np.vectorize(grad(nngp_fn))

    @_requires(diagonal_spatial=_Diagonal())  # pytype:disable=wrong-keyword-args
    def kernel_fn(k: Kernel) -> Kernel:
      cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

      var1 = _get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
      var2 = _get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

      if ntk is not None:
        ntk *= _vmap_2d(d_nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)

      nngp = _vmap_2d(nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)
      cov1 = _vmap_2d(
          nngp_fn, cov1, var1, None, k.diagonal_batch, k.diagonal_spatial)
      if cov2 is not None:
        cov2 = _vmap_2d(
            nngp_fn, cov2, var2, None, k.diagonal_batch, k.diagonal_spatial)
      return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, name, kernel_fn)
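The `d_nngp_fn = d(nngp_fn)/d(cov12)` rule above can be sanity-checked against a case with a known derivative. For Erf, E[fn'(x1) * fn'(x2)] = 4 / (pi * sqrt((1 + 2*var1) * (1 + 2*var2) - 4*cov12**2)), and autodiff recovers it (a standalone sketch, not part of the diff; `d_nngp_closed_form` and the test values are illustrative):

import jax.numpy as np
from jax import grad

def nngp_fn(cov12, var1, var2):
  prod = (1 + 2 * var1) * (1 + 2 * var2)
  return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

def d_nngp_closed_form(cov12, var1, var2):
  # Known closed form of E[erf'(x1) * erf'(x2)].
  prod = (1 + 2 * var1) * (1 + 2 * var2)
  return 4 / (np.pi * np.sqrt(prod - 4 * cov12 ** 2))

c, v1, v2 = 0.2, 1.0, 1.5
# `grad` differentiates w.r.t. the first argument, `cov12`.
assert np.allclose(grad(nngp_fn)(c, v1, v2), d_nngp_closed_form(c, v1, v2))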


@layer
@supports_masking(remask_kernel=True)
def ElementwiseNumerical(
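Regarding the NaN warning above: if a user's `nngp_fn` guards a singular point with `np.where`, autodiff can still produce NaN gradients through the untaken branch. The fix from the linked FAQ is the "double-where" pattern, sketched here with a hypothetical `safe_nngp_fn` (for the Erf integral the ratio is in fact always inside (-1, 1), so the guard is purely illustrative):

import jax.numpy as np

def safe_nngp_fn(cov12, var1, var2):
  prod = (1 + 2 * var1) * (1 + 2 * var2)
  ratio = 2 * cov12 / np.sqrt(prod)
  safe = np.abs(ratio) < 1.
  # Inner `where` keeps the untaken branch NaN-free under autodiff;
  # outer `where` selects the actual value.
  return np.where(safe,
                  np.arcsin(np.where(safe, ratio, 0.)) * 2 / np.pi,
                  np.sign(ratio))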
@@ -5158,6 +5262,71 @@ def _diag_mul(
  return _diag_mul_full_spatial(x, factor, diagonal_batch)


def _vmap_2d(fn: Callable[[float, float, float], float],
             cov12: np.ndarray,
             var1: np.ndarray,
             var2: Optional[np.ndarray],
             diagonal_batch: bool,
             diagonal_spatial: bool) -> np.ndarray:
  """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.
    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.
    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.
    var2:
      variance tensor (`q22`), has shape `(N2[, X, Y, ...])`.
    diagonal_batch:
      `True` if `cov12` has only one batch dimension.
    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
  batch_ndim = 1 if diagonal_batch else 2
  start = 2 - batch_ndim
  cov_end = batch_ndim if diagonal_spatial else cov12.ndim
  _cov12 = utils.make_2d(cov12, start, cov_end)

  var_end = 1 if diagonal_spatial else var1.ndim
  var1 = var1.reshape(var1.shape[:start] + (-1,) + var1.shape[var_end:])
  var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1,) +
                                                var2.shape[var_end:])

  fn = vmap(
      vmap(
          np.vectorize(fn),
          in_axes=(start, None, start),
          out_axes=start
      ),
      in_axes=(start, start, None),
      out_axes=start
  )
  out = fn(_cov12, var1, var2)  # type: np.ndarray
  out_shape = (cov12.shape[:start] +
               cov12.shape[start:cov_end:2] +
               cov12.shape[start + 1:cov_end:2] +
               cov12.shape[cov_end:])
  out = out.reshape(out_shape)
  out = utils.zip_axes(out, start, cov_end)
  return out
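The core trick in `_vmap_2d` is the nested `vmap` over the two batch axes, with `np.vectorize` handling any remaining spatial axes elementwise. A toy sketch of the same pattern on a plain `(N1, N2)` covariance matrix (names `toy_fn`, `c12`, `v1`, `v2` are made up for illustration):

import jax.numpy as np
from jax import vmap

def toy_fn(cov12, var1, var2):
  # Any scalar function of the three kernel entries.
  return cov12 / np.sqrt(var1 * var2)

n1, n2 = 3, 4
c12 = np.ones((n1, n2))     # pairwise covariances
v1 = np.arange(1., n1 + 1)  # variances of the first batch, shape (n1,)
v2 = np.arange(1., n2 + 1)  # variances of the second batch, shape (n2,)

# Inner vmap pairs each column of a row with `v2`; outer vmap pairs each
# row with `v1` -- the same nesting as in `_vmap_2d` with `start = 0`.
fn_2d = vmap(vmap(toy_fn, in_axes=(0, None, 0)), in_axes=(0, 0, None))
out = fn_2d(c12, v1, v2)  # out[i, j] == toy_fn(c12[i, j], v1[i], v2[j])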


# MASKING


84 changes: 84 additions & 0 deletions tests/stax_test.py
@@ -986,6 +986,90 @@ def test_rbf(self, same_inputs, model, get, gamma):
        rbf_gamma=gamma)



class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list(
          {
              'testcase_name': '_{}_{}_n={}_diag_batch={}_spatial={}'.format(
                  phi[0].__name__, same_inputs, n, diagonal_batch,
                  diagonal_spatial),
              'phi': phi,
              'same_inputs': same_inputs,
              'n': n,
              'diagonal_batch': diagonal_batch,
              'diagonal_spatial': diagonal_spatial
          }
          for phi in [stax.Identity(), stax.Erf(), stax.Sin(), stax.Relu()]
          for same_inputs in [False, True, None]
          for n in [0, 1, 2]
          for diagonal_batch in [True, False]
          for diagonal_spatial in [True, False]))
  def test_elementwise(self, same_inputs, phi, n, diagonal_batch,
                       diagonal_spatial):
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = stax._sqrt(prod - cov12 ** 2)
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
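For reference, the `Sin` branch in `nngp_fn` above equals exp(-(var1 + var2) / 2) * sinh(cov12), which follows from E[cos z] = exp(-Var[z] / 2) for a centered Gaussian `z`, applied to z = x1 - x2 and z = x1 + x2. A quick Monte Carlo check of that closed form (a standalone sketch, not part of the commit; seed and values are arbitrary):

import numpy as onp

cov12, var1, var2 = 0.3, 1.0, 2.0
cov = onp.array([[var1, cov12], [cov12, var2]])
x = onp.random.default_rng(0).multivariate_normal(
    onp.zeros(2), cov, size=1_000_000)
mc = onp.mean(onp.sin(x[:, 0]) * onp.sin(x[:, 1]))
closed = onp.exp(-(var1 + var2) / 2) * onp.sinh(cov12)
# `mc` and `closed` agree to roughly three decimal places.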


class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):

  @jtu.parameterized.named_parameters(
