Skip to content

Commit

Permalink
Add an Index layer mimicking numpy indexing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 460343414
  • Loading branch information
romanngg committed Jul 12, 2022
1 parent 083e393 commit 6f10d16
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 15 deletions.
5 changes: 3 additions & 2 deletions docs/stax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ Pointwise nonlinear layers.
Sin


Helper enums
Helper classes
--------------------------------------
Enums for specifying layer properties. Strings can be used in their place.
Utility classes for specifying layer properties. For enums, strings can be passed in their place.

.. autosummary::
:toctree: _autosummary
Expand All @@ -104,6 +104,7 @@ Enums for specifying layer properties. Strings can be used in their place.
AttentionMechanism
Padding
PositionalEmbedding
Slice


For developers
Expand Down
100 changes: 87 additions & 13 deletions neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,9 +1596,9 @@ def kernel_fn(k: Kernel, **kwargs) -> Kernel:
window_shape_kernel = window_shape
strides_kernel = strides
else:
window_shape_kernel = _double_tuple(
window_shape_kernel = utils.double_tuple(
window_shape[::(-1 if k.is_reversed else 1)])
strides_kernel = _double_tuple(strides[::(-1 if k.is_reversed else 1)])
strides_kernel = utils.double_tuple(strides[::(-1 if k.is_reversed else 1)])

def pool(mat, batch_ndim):
if mat is None or mat.ndim == 0:
Expand Down Expand Up @@ -2785,6 +2785,87 @@ def resize(k, shape1, shape2, diagonal_batch):
return init_fn, apply_fn, kernel_fn, mask_fn


@layer
@supports_masking(remask_kernel=False)
def Index(
    idx: utils.SliceType,
    batch_axis: int = 0,
    channel_axis: int = -1
) -> InternalLayerMasked:
  """Index into the array mimicking :class:`numpy.ndarray` indexing.

  .. warning::
    Two limitations in the kernel regime (`kernel_fn`): the `channel_axis`
    (infinite width) cannot be indexed, and the `batch_axis` can only be
    indexed with tuples/slices, but not integers, since the library requires
    there always to be a batch axis in a `Kernel`.

  Args:
    idx:
      a `slice` object that would result from indexing an array as `x[idx]`.
      To create this object, use the helper object :obj:`Slice`, i.e. pass
      `idx=stax.Slice[1:10, :, ::-1]` (which is equivalent to passing an
      explicit `idx=(slice(1, 10, None), slice(None), slice(None, None, -1))`).

    batch_axis:
      batch axis for `inputs`. Defaults to `0`, the leading axis.

    channel_axis:
      channel axis for `inputs`. Defaults to `-1`, the trailing axis. For
      `kernel_fn`, channel size is considered to be infinite.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Example:
    >>> from neural_tangents import stax
    >>> #
    >>> init_fn, apply_fn, kernel_fn = stax.serial(
    >>>     stax.Conv(128, (3, 3)),
    >>>     stax.Relu(),
    >>>     # Select every other element from the batch (leading axis), cropped
    >>>     # to the upper-left 4x4 corner.
    >>>     stax.Index(idx=stax.Slice[::2, :4, :4]),
    >>>     stax.Conv(128, (2, 2)),
    >>>     stax.Relu(),
    >>>     # Select the first row. Notice that the image becomes 1D.
    >>>     stax.Index(idx=stax.Slice[:, 0, ...]),
    >>>     stax.Conv(128, (2,)),
    >>>     stax.GlobalAvgPool(),
    >>>     stax.Dense(10)
    >>> )
  """
  def init_fn(rng, input_shape):
    # The layer has no trainable parameters (empty tuple); only the output
    # shape changes, as computed by `utils.slice_shape`.
    return utils.slice_shape(input_shape, idx), ()

  def apply_fn(params, x, **kwargs):
    # Finite-width forward pass: plain array indexing.
    return x[idx]

  def mask_fn(mask, input_shape):
    # The mask is indexed with the same `idx` as the inputs.
    return mask[idx]

  @requires(batch_axis=batch_axis, channel_axis=channel_axis)
  def kernel_fn(k: Kernel, **kwargs) -> Kernel:
    # Delegates to the indexing operator of `k` (annotated as `Kernel`).
    return k[idx]

  return init_fn, apply_fn, kernel_fn, mask_fn


class _Slice:
  """Helper whose subscript syntax returns the index expression it receives."""

  def __getitem__(self, idx: utils.SliceType) -> utils.SliceType:
    # By the time `__getitem__` is called, Python has already turned the
    # subscript (e.g. `1, :, 2:8:3`) into ints / `slice` objects / `Ellipsis`
    # (or a tuple of them), so it is handed back unchanged.
    return idx


Slice = _Slice()
"""A helper object to pass the slicing index `idx` to the :obj:`Index` layer.
Since we cannot pass slice specifications like `1, :, 2:8:3` as function
arguments, pass `Slice[1, :, 2:8:3] == (1, slice(None), slice(2, 8, 3))`
instead.
"""


# INTERNAL UTILITIES


Expand Down Expand Up @@ -3004,8 +3085,8 @@ def _conv_kernel_full_spatial_shared(

if padding == Padding.CIRCULAR:
spatial_axes = tuple(range(batch_ndim, lhs.ndim))
total_filter_shape = _double_tuple(filter_shape)
total_strides = _double_tuple(strides)
total_filter_shape = utils.double_tuple(filter_shape)
total_strides = utils.double_tuple(strides)
lhs = _same_pad_for_filter_shape(lhs,
total_filter_shape,
total_strides,
Expand Down Expand Up @@ -3059,13 +3140,6 @@ def get_n_channels(batch_and_channels: int) -> int:
return out


_T = TypeVar('_T')


def _double_tuple(x: Iterable[_T]) -> Tuple[_T, ...]:
  """Return a tuple where each element of `x` is repeated twice in a row.

  Args:
    x: any iterable; it is consumed exactly once, in order.

  Returns:
    `(x0, x0, x1, x1, ...)`.
  """
  return tuple(copy for item in x for copy in (item, item))


def _conv_kernel_full_spatial_unshared(
lhs: Optional[np.ndarray],
filter_shape: Sequence[int],
Expand Down Expand Up @@ -3173,8 +3247,8 @@ def get_n_channels(batch_and_channels: int) -> int:

if padding == Padding.CIRCULAR:
spatial_axes = tuple(range(batch_ndim, out.ndim))
total_filter_shape = _double_tuple(filter_shape)
total_strides = _double_tuple(strides)
total_filter_shape = utils.double_tuple(filter_shape)
total_strides = utils.double_tuple(strides)
out_shape = eval_shape(lambda x: _pool_transpose(x,
total_filter_shape,
total_strides,
Expand Down
68 changes: 68 additions & 0 deletions neural_tangents/_src/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,71 @@ def __neg__(self) -> 'Kernel':
return self

__pos__ = __neg__

def __getitem__(self, idx: utils.SliceType) -> 'Kernel':
  """Slice the kernel as if the underlying inputs were indexed with `idx`.

  Args:
    idx:
      an int / `slice` / `Ellipsis`, or a tuple of them, expressed in terms
      of the input axes described by `shape1`.

  Returns:
    A new `Kernel` whose `nngp`, `ntk`, `cov1`, `cov2`, masks, shapes and
    batch / channel axis positions reflect the indexing.

  Raises:
    NotImplementedError:
      if the channel axis is indexed with anything other than a full slice,
      or if the batch axis is indexed with an integer.
  """
  full_idx = utils.canonicalize_idx(idx, len(self.shape1))

  channel_idx = full_idx[self.channel_axis]
  batch_idx = full_idx[self.batch_axis]

  # The channel axis is treated as infinite and must be left untouched.
  if channel_idx != slice(None):
    raise NotImplementedError(
        f'Indexing into the (infinite) channel axis {self.channel_axis} not '
        f'supported.'
    )

  # An integer index would drop the batch axis altogether.
  if isinstance(batch_idx, int):
    raise NotImplementedError(
        f'Indexing an axis with an integer index (e.g. `0` vs `(0,)` removes '
        f'the respective axis. Neural Tangents requires there to always be a '
        f'batch axis ({self.batch_axis}), so it cannot be indexed with '
        f'integers (please use tuples or `slice` instead).'
    )

  # Everything that is neither batch nor channel is a spatial index.
  spatial_idx = tuple(
      entry for axis, entry in enumerate(full_idx)
      if axis != self.batch_axis and axis != self.channel_axis
  )
  if self.is_reversed:
    spatial_idx = tuple(reversed(spatial_idx))
  if not self.diagonal_spatial:
    # Each spatial index is applied to a pair of consecutive kernel axes.
    spatial_idx = utils.double_tuple(spatial_idx)

  nngp_idx = (batch_idx, batch_idx) + spatial_idx
  if self.diagonal_batch:
    cov_idx = (batch_idx,) + spatial_idx
  else:
    cov_idx = (batch_idx, batch_idx) + spatial_idx

  nngp = self.nngp[nngp_idx]

  # A missing or scalar `ntk` is passed through unchanged.
  ntk = self.ntk
  if ntk is not None and ntk.ndim != 0:  # pytype: disable=attribute-error
    ntk = ntk[nngp_idx]

  cov1 = self.cov1[cov_idx]
  cov2 = self.cov2
  if cov2 is not None:
    cov2 = cov2[cov_idx]

  # Integer indices remove their axis, shifting later axes to the left. Walk
  # the positions from last to first, comparing against the running values.
  new_channel_axis = self.channel_axis
  new_batch_axis = self.batch_axis
  for position in range(len(full_idx) - 1, -1, -1):
    if not isinstance(full_idx[position], int):
      continue
    if position < new_channel_axis:
      new_channel_axis -= 1
    if position < new_batch_axis:
      new_batch_axis -= 1

  return self.replace(
      nngp=nngp,
      ntk=ntk,
      cov1=cov1,
      cov2=cov2,
      channel_axis=new_channel_axis,
      batch_axis=new_batch_axis,
      shape1=utils.slice_shape(self.shape1, full_idx),
      shape2=utils.slice_shape(self.shape2, full_idx),
      mask1=None if self.mask1 is None else self.mask1[full_idx],
      mask2=None if self.mask2 is None else self.mask2[full_idx],
  )
65 changes: 65 additions & 0 deletions neural_tangents/_src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,68 @@ def split_kwargs(kwargs, x1=None, x2=None):
kwargs1[k] = kwargs2[k] = v

return kwargs1, kwargs2


# Index for a single position: an integer, a `slice`, or `Ellipsis` (`...`).
_SingleSlice = Union[int, slice, type(Ellipsis)]


# Either one such index on its own, or a tuple of them.
SliceType = Union[_SingleSlice, Tuple[_SingleSlice, ...]]
"""A type to specify a slice of an array.
For instance, when indexing `x[1, :, 2:8:3]` a slice tuple
`(1, slice(None), slice(2, 8, 3))` is created. But since functions cannot
accept slice specifications like `1, :, 2:8:3` as arguments, you must either
pass this tuple explicitly, or, for convenience, use the
:obj:`~neural_tangents.stax.Slice` helper, e.g. `nt.stax.Slice[1, :, 2:8:3]`.
"""


def canonicalize_idx(
    idx: SliceType,
    ndim: int
) -> Tuple[Union[int, slice], ...]:
  """Expand `idx` into a tuple with exactly one int / `slice` per axis.

  A lone int / `slice` / `Ellipsis` is wrapped into a tuple, a single
  `Ellipsis` is replaced by as many `slice(None)` as needed, and missing
  trailing axes are padded with `slice(None)`.

  Args:
    idx: the index to canonicalize.
    ndim: number of axes of the array being indexed.

  Returns:
    A tuple of length `ndim` containing only integers and `slice` objects.

  Raises:
    IndexError:
      if `idx` has more than one `Ellipsis`, or indexes more than `ndim`
      axes (both are invalid for array indexing as well).
  """
  if idx is Ellipsis or isinstance(idx, (int, slice)):
    idx = (idx,)
  else:
    idx = tuple(idx)

  # Identity (not equality) test, so that entries are never compared by value.
  ellipsis_positions = [i for i, s in enumerate(idx) if s is Ellipsis]
  if len(ellipsis_positions) > 1:
    raise IndexError(
        f'An index can only have a single ellipsis (`...`), got {idx}.')

  # Number of axes not explicitly indexed (the ellipsis itself indexes none).
  n_missing = ndim - (len(idx) - len(ellipsis_positions))
  if n_missing < 0:
    raise IndexError(
        f'Too many indices: array is {ndim}-dimensional, but {idx} indexes '
        f'{ndim - n_missing} axes.')

  fill = (slice(None),) * n_missing
  if ellipsis_positions:
    i = ellipsis_positions[0]
    return idx[:i] + fill + idx[i + 1:]

  return idx + fill


def slice_shape(shape: Tuple[int, ...], idx: SliceType) -> Tuple[int, ...]:
  """Compute the shape obtained by indexing an array of `shape` with `idx`.

  Axes of unknown size (`None` or negative) are preserved as-is in the
  output, provided they are not indexed into (i.e. indexed with `:`).

  Args:
    shape: shape of the array being indexed; entries may be `None` or
      negative to denote an unknown size.
    idx: the index applied to the array.

  Returns:
    The shape of the indexed array.

  Raises:
    ValueError:
      if an axis of unknown size is indexed with anything other than `:`.
  """
  # Keep `None` or negative-sized axes if they aren't indexed into.
  canonical_idx = canonicalize_idx(idx, len(shape))

  np_shape = list(shape)
  unknown_axes = {}
  n_ints = 0  # Keep track of vanishing axes due to integer indexing.

  for a, (i, s) in enumerate(zip(canonical_idx, shape)):
    # `None` must be tested first: `None < 0` raises a `TypeError`.
    if s is None or s < 0:
      if i == slice(None):
        # Placeholder size so that the dummy array below can be built; the
        # original (unknown) size is restored at the shifted output position.
        np_shape[a] = 0
        unknown_axes[a - n_ints] = s
      else:
        raise ValueError(
            f'Trying to index with {i} axis {a} of unknown size {s}. '
            f'Please provide input shape {shape} with non-negative integer '
            f'size at axis {a}.')

    if isinstance(i, int):
      n_ints += 1

  # Let numpy work out the resulting shape by indexing a dummy array.
  out_shape = list(onp.empty(np_shape)[idx].shape)
  for a, v in unknown_axes.items():
    out_shape[a] = v

  return tuple(out_shape)


_T = TypeVar('_T')


def double_tuple(x: Iterable[_T]) -> Tuple[_T, ...]:
  """Return a tuple in which every element of `x` appears twice in a row.

  Args:
    x: any iterable; it is consumed exactly once, in order.

  Returns:
    `(x0, x0, x1, x1, ...)`.
  """
  repeated = []
  for item in x:
    repeated.append(item)
    repeated.append(item)
  return tuple(repeated)
7 changes: 7 additions & 0 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
ConvTranspose,
Dense,
Identity,
Index,
DotGeneral,
Dropout,
Flatten,
Expand All @@ -123,6 +124,12 @@
)


# Helper object for the `Index` layer.
from ._src.stax.linear import (
Slice
)


# Branching layers.
from ._src.stax.branching import (
FanInConcat,
Expand Down
Loading

0 comments on commit 6f10d16

Please sign in to comment.