From eec18f56562864540f8141372d7df67f9c34373e Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Fri, 24 Jun 2022 12:14:07 -0700 Subject: [PATCH] BEGIN_PUBLIC 1) Various documentation improvements: - use proper RST syntax for warning/see-also panes. - set default role to `code` globally to better render `code blocks` in auto-generated pages. - change the `code` CSS style. - better document enums. - add some missing public enums to `stax`. - use RST :roles: and intersphinx to create clickable references to NT and JAX/numpy/etc objects. - Make `Kernel` public, and remove a circular dependency in typing -> Kernel -> utils -> typing. - Fix interactive python `>>>` code blocks to not create extra `>>>` or spaces when copy-pasting. - Add citations and references to all papers that contributed to NT. - Add basic `experimental` documentation. - Minor typing and linting fixes. 2) Fix a bug in the empirical kernel of grouped CNN, and add respective tests. 3) Add more tests to Tensorflow NTK, comparing TF and JAX NTKs. 4) Reduce some sizes to deflake tests. Tests: http://sponge2/fa1e9c8a-6fde-45d9-b301-93c7c47d28d8 PiperOrigin-RevId: 457065540 --- CITATION | 32 + README.md | 39 +- docs/_static/style.css | 5 + docs/batching.rst | 4 +- docs/conf.py | 9 +- docs/empirical.rst | 10 +- docs/experimental.rst | 24 + docs/index.rst | 19 +- docs/kernel.rst | 6 +- docs/monte_carlo.rst | 4 +- docs/predict.rst | 2 +- docs/stax.rst | 5 +- docs/typing.rst | 2 +- examples/empirical_ntk.py | 3 + examples/experimental/empirical_ntk_tf.py | 10 +- neural_tangents/__init__.py | 8 +- neural_tangents/_src/batching.py | 44 +- neural_tangents/_src/empirical.py | 396 +++++++------ neural_tangents/_src/monte_carlo.py | 52 +- neural_tangents/_src/predict.py | 205 ++++--- neural_tangents/_src/stax/branching.py | 11 +- neural_tangents/_src/stax/combinators.py | 16 +- neural_tangents/_src/stax/elementwise.py | 35 +- neural_tangents/_src/stax/linear.py | 550 +++++++++++------- neural_tangents/_src/stax/requirements.py | 158 +++-- neural_tangents/_src/utils/dataclasses.py | 35 +- neural_tangents/_src/utils/kernel.py | 6 +- neural_tangents/_src/utils/rules.py | 22 +- neural_tangents/_src/utils/typing.py | 19 +- neural_tangents/_src/utils/utils.py | 42 +- .../experimental/empirical_tf/empirical.py | 230 ++++---- neural_tangents/stax.py | 93 +-- notebooks/empirical_ntk_resnet.ipynb | 1 + .../empirical_ntk_resnet_tf.ipynb | 3 +- setup.py | 12 +- tests/empirical_test.py | 247 ++++++-- tests/experimental/empirical_tf_test.py | 231 ++++++-- 37 files changed, 1650 insertions(+), 940 deletions(-) create mode 100644 docs/experimental.rst diff --git a/CITATION b/CITATION index f8deb372..58e5ea20 100644 --- a/CITATION +++ b/CITATION @@ -1,7 +1,39 @@ +# Infinite width NTK/NNGP: @inproceedings{neuraltangents2020, title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python}, author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz}, booktitle={International Conference on Learning Representations}, year={2020}, + pdf={https://arxiv.org/abs/1912.02803}, + url={https://github.com/google/neural-tangents} +} + +# Finite width, empirical NTK/NNGP: +@inproceedings{novak2022fast, + title={Fast Finite Width Neural Tangent Kernel}, + author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz}, + booktitle={International Conference on Machine Learning}, + year={2022}, + pdf={https://arxiv.org/abs/2206.08720}, + url={https://github.com/google/neural-tangents} +} + +# Attention and variable-length inputs: +@inproceedings{hron2020infinite, + title={Infinite attention: NNGP and NTK for deep attention networks}, + author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak}, + booktitle={International Conference on Machine Learning}, + year={2020}, + pdf={https://arxiv.org/abs/2006.10540}, + url={https://github.com/google/neural-tangents} +} + +# Infinite-width "standard" parameterization: +@misc{sohl2020on, + title={On the infinite width limit of neural networks with a standard parameterization}, + author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee}, + publisher = {arXiv}, + year={2020}, + pdf={https://arxiv.org/abs/2001.07301}, url={https://github.com/google/neural-tangents} } diff --git a/README.md b/README.md index 39978caf..81a2159f 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ Freedom of thought is fundamental to all of science. Right now, our freedom is being suppressed with carpet bombing of civilians in Ukraine. **Don't be against the war - fight against the war!** Support Ukraine at **[stopputin.net](https://stopputin.net/)**. +### News + +Our paper "[Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720)" has been accepted to ICML2022, and the respective code [submitted](https://github.com/romanngg/neural-tangents/commit/60b6c16758652f4526536409b5bc90602287b868) (available starting from version `0.6.0`). + # Neural Tangents [**ICLR 2020 Video**](https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html) | [**Paper**](https://arxiv.org/abs/1912.02803) @@ -60,6 +64,7 @@ An easy way to get started with Neural Tangents is by playing around with the fo - Empirical NTK: - [Fully-connected network](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb) - [FLAX ResNet18](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb) + - [Experimental: Tensorflow ResNet50](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb) ## Installation @@ -528,14 +533,46 @@ to the list! ## Citation -If you use the code in a publication, please cite our ICLR 2020 paper: +If you use the code in a publication, please cite our papers: ``` +# Infinite width NTK/NNGP: @inproceedings{neuraltangents2020, title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python}, author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz}, booktitle={International Conference on Learning Representations}, year={2020}, + pdf={https://arxiv.org/abs/1912.02803}, + url={https://github.com/google/neural-tangents} +} + +# Finite width, empirical NTK/NNGP: +@inproceedings{novak2022fast, + title={Fast Finite Width Neural Tangent Kernel}, + author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz}, + booktitle={International Conference on Machine Learning}, + year={2022}, + pdf={https://arxiv.org/abs/2206.08720}, + url={https://github.com/google/neural-tangents} +} + +# Attention and variable-length inputs: +@inproceedings{hron2020infinite, + title={Infinite attention: NNGP and NTK for deep attention networks}, + author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak}, + booktitle={International Conference on Machine Learning}, + year={2020}, + pdf={https://arxiv.org/abs/2006.10540}, + url={https://github.com/google/neural-tangents} +} + +# Infinite-width "standard" parameterization: +@misc{sohl2020on, + title={On the infinite width limit of neural networks with a standard parameterization}, + author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee}, + publisher = {arXiv}, + year={2020}, + pdf={https://arxiv.org/abs/2001.07301}, url={https://github.com/google/neural-tangents} } ``` diff --git a/docs/_static/style.css b/docs/_static/style.css index b07bdb1b..ebc06441 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,3 +1,8 @@ .wy-nav-content { max-width: none; } + +.rst-content code.literal, .rst-content tt.literal { + color: #404040; + white-space: normal +} diff --git a/docs/batching.rst b/docs/batching.rst index 1a1a44c3..42051cd3 100644 --- a/docs/batching.rst +++ b/docs/batching.rst @@ -1,11 +1,11 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/batching.rst -.. default-role:: code + `nt.batch` -- using multiple devices ============================================================ -.. default-role:: code + .. automodule:: neural_tangents._src.batching .. automodule:: neural_tangents :noindex: diff --git a/docs/conf.py b/docs/conf.py index 6d318b81..4647f1bd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -80,6 +80,7 @@ 'python': ('https://docs.python.org/3/', None), 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), + 'jax': ('https://jax.readthedocs.io/en/latest/', None), } @@ -172,16 +173,20 @@ 'Miscellaneous'), ] + # add_module_names = False +default_role = 'code' + + def remove_module_docstring(app, what, name, obj, options, lines): - if what == "module" and name == "neural_tangents": + if what == 'module' and name == 'neural_tangents': del lines[:] def setup(app): - app.connect("autodoc-process-docstring", remove_module_docstring) + app.connect('autodoc-process-docstring', remove_module_docstring) app.add_css_file('style.css') diff --git a/docs/empirical.rst b/docs/empirical.rst index 4845e8b5..cac5fa59 100644 --- a/docs/empirical.rst +++ b/docs/empirical.rst @@ -1,6 +1,6 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/empirical.rst -.. default-role:: code + `nt.empirical` -- finite NNGP and NTK ====================================== @@ -21,13 +21,9 @@ Finite-width NNGP and/or NTK kernel functions. NTK implementation -------------------------------------- -An `IntEnum` specifying NTK implementation method. - -.. autosummary:: - :toctree: _autosummary - - NtkImplementation +An :class:`enum.IntEnum` specifying NTK implementation method. +.. autoclass:: NtkImplementation Linearization and Taylor expansion -------------------------------------- diff --git a/docs/experimental.rst b/docs/experimental.rst new file mode 100644 index 00000000..e5bdc866 --- /dev/null +++ b/docs/experimental.rst @@ -0,0 +1,24 @@ +:github_url: https://github.com/google/neural-tangents/tree/main/docs/experimental.rst + + + +`nt.experimental` -- prototypes +====================================== + +.. warning:: + This module contains new highly-experimental prototypes. Please beware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk! + +.. automodule:: neural_tangents.experimental +.. currentmodule:: neural_tangents.experimental + +Kernel functions +-------------------------------------- +Finite-width NTK kernel function *in Tensorflow*. See the `Python `_ and `Colab `_ usage examples. + +.. autofunction:: empirical_ntk_fn_tf + +Helper functions +-------------------------------------- +A helper function to convert Tensorflow stateful models into functional-style, stateless `apply_fn(params, x)` forward pass function and extract the respective `params`. + +.. autofunction:: get_apply_fn_and_params diff --git a/docs/index.rst b/docs/index.rst index 236a3a03..f65f37b4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,6 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/index.rst -.. default-role:: code + Neural Tangents Reference =========================================== @@ -24,6 +24,7 @@ neural networks (a.k.a. NTK, NNGP). kernel typing + experimental .. toctree:: :maxdepth: 2 @@ -34,15 +35,25 @@ neural networks (a.k.a. NTK, NNGP). Function Space Linearization Neural Network Phase Diagram Performance Benchmarks + Finite Width NTK + +.. toctree:: + :maxdepth: 2 + :caption: Papers: + + Neural Tangents: Fast and Easy Infinite Neural Networks in Python + Fast Finite Width Neural Tangent Kernel + Infinite attention: NNGP and NTK for deep attention networks + On the infinite width limit of neural networks with a standard parameterization .. toctree:: :maxdepth: 2 :caption: Other Resources: - GitHub - Paper - Video + Neural Tangents Video + Finite Width NTK Video Wikipedia + GitHub Indices and tables ================== diff --git a/docs/kernel.rst b/docs/kernel.rst index d656cdff..c8b75c79 100644 --- a/docs/kernel.rst +++ b/docs/kernel.rst @@ -1,9 +1,9 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/kernel.rst -.. default-role:: code -`Kernel` dataclass + +:class:`~neural_tangents.Kernel` dataclass ============================================================= -.. automodule:: neural_tangents._src.utils.kernel +.. autoclass:: neural_tangents.Kernel :members: diff --git a/docs/monte_carlo.rst b/docs/monte_carlo.rst index 01b54f30..1a25f656 100644 --- a/docs/monte_carlo.rst +++ b/docs/monte_carlo.rst @@ -1,11 +1,11 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/monte_carlo.rst -.. default-role:: code + `nt.monte_carlo_kernel_fn` - MC Sampling ========================================= -.. default-role:: code + .. automodule:: neural_tangents._src.monte_carlo .. automodule:: neural_tangents :members: monte_carlo_kernel_fn diff --git a/docs/predict.rst b/docs/predict.rst index e7a85279..c976105e 100644 --- a/docs/predict.rst +++ b/docs/predict.rst @@ -1,6 +1,6 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/predict.rst -.. default-role:: code + `nt.predict` -- inference w/ NNGP & NTK ============================================================= diff --git a/docs/stax.rst b/docs/stax.rst index b9596e4c..04abab32 100644 --- a/docs/stax.rst +++ b/docs/stax.rst @@ -1,10 +1,11 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/stax.rst -.. default-role:: code + `nt.stax` -- infinite NNGP and NTK =========================================== + .. automodule:: neural_tangents.stax @@ -99,6 +100,8 @@ Enums for specifying layer properties. Strings can be used in their place. .. autosummary:: :toctree: _autosummary + AggregateImplementation + AttentionMechanism Padding PositionalEmbedding diff --git a/docs/typing.rst b/docs/typing.rst index f7eb948d..b6abad08 100644 --- a/docs/typing.rst +++ b/docs/typing.rst @@ -1,6 +1,6 @@ :github_url: https://github.com/google/neural-tangents/tree/main/docs/typing.rst -.. default-role:: code + Typing ============================================================= diff --git a/examples/empirical_ntk.py b/examples/empirical_ntk.py index f19512dd..75851d41 100644 --- a/examples/empirical_ntk.py +++ b/examples/empirical_ntk.py @@ -16,6 +16,9 @@ All implementations apply to any differentiable functions, (not necessarily ones constructed with Neural Tangents). + +For details about the empirical (finite width) NTK computation, please see +"`Fast Finite Width Neural Tangent Kernel `_". """ from absl import app diff --git a/examples/experimental/empirical_ntk_tf.py b/examples/experimental/empirical_ntk_tf.py index 4c6b83db..a45676aa 100644 --- a/examples/experimental/empirical_ntk_tf.py +++ b/examples/experimental/empirical_ntk_tf.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Minimal highly-experimental Tensorflow NTK example.""" +"""Minimal highly-experimental Tensorflow NTK example. + +Specifically, Tensorflow NTK appears to have very long compile times (but OK +runtime), is prone to triggering XLA errors, and does not distinguish between +trainable and non-trainable parameters of the model. + +For details about the empirical (finite width) NTK computation, please see +"`Fast Finite Width Neural Tangent Kernel `_". +""" from absl import app import neural_tangents as nt diff --git a/neural_tangents/__init__.py b/neural_tangents/__init__.py index 8d154679..d137f293 100644 --- a/neural_tangents/__init__.py +++ b/neural_tangents/__init__.py @@ -18,19 +18,15 @@ __version__ = '0.6.0' - +from . import experimental from . import predict from . import stax - from ._src.batching import batch - from ._src.empirical import empirical_kernel_fn from ._src.empirical import empirical_nngp_fn from ._src.empirical import empirical_ntk_fn from ._src.empirical import linearize from ._src.empirical import NtkImplementation from ._src.empirical import taylor_expand - from ._src.monte_carlo import monte_carlo_kernel_fn - -from . import experimental +from ._src.utils.kernel import Kernel diff --git a/neural_tangents/_src/batching.py b/neural_tangents/_src/batching.py index 249d5b00..7c9c7064 100644 --- a/neural_tangents/_src/batching.py +++ b/neural_tangents/_src/batching.py @@ -20,25 +20,25 @@ the result, allowing to both use multiple accelerators and stay within memory limits. -Note that you typically should not apply the `jax.jit` decorator to the +Note that you typically should not apply the :obj:`jax.jit` decorator to the resulting `batched_kernel_fn`, as its purpose is explicitly serial execution in -order to save memory. Further, you do not need to apply `jax.jit` to the input -`kernel_fn` function, as it is JITted internally. +order to save memory. Further, you do not need to apply :obj:`jax.jit` to the +input `kernel_fn` function, as it is JITted internally. Example: - >>> from jax import numpy as np - >>> import neural_tangents as nt - >>> from neural_tangents import stax - >>> - >>> # Define some kernel function. - >>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1)) - >>> - >>> # Compute the kernel in batches, in parallel. - >>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5) - >>> - >>> # Generate dummy input data. - >>> x1, x2 = np.ones((40, 10)), np.ones((80, 10)) - >>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2) # True! + >>> from jax import numpy as np + >>> import neural_tangents as nt + >>> from neural_tangents import stax + >>> # + >>> # Define some kernel function. + >>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1)) + >>> # + >>> # Compute the kernel in batches, in parallel. + >>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5) + >>> # + >>> # Generate dummy input data. + >>> x1, x2 = np.ones((40, 10)), np.ones((80, 10)) + >>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2) # True! """ @@ -137,7 +137,7 @@ def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]], xs: Iterable[_Input]) -> Tuple[_Carry, _Output]: """Implements an unrolled version of scan. - Based on `jax.lax.scan` and has a similar API. + Based on :obj:`jax.lax.scan` and has a similar API. TODO(schsam): We introduce this function because lax.scan currently has a higher peak memory usage than the unrolled version. We will aim to swap this @@ -353,8 +353,8 @@ def get_n1_n2(x1, x2): return n1, n2 n1, n2 = get_n1_n2(x1, x2) - (n1_batches, n1_batch_size, n2_batches, n2_batch_size) = \ - _get_n_batches_and_batch_sizes(n1, n2, batch_size, device_count) + (n1_batches, n1_batch_size, n2_batches, n2_batch_size) = ( + _get_n_batches_and_batch_sizes(n1, n2, batch_size, device_count)) @utils.nt_tree_fn(nargs=1) def batch_input(x, batch_count, batch_size): @@ -735,9 +735,9 @@ def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs): kwargs_other[k] = v # Check cache before jitting. - _key = key + \ - tuple(args_other.items()) + \ - tuple(kwargs_other.items()) + _key = key + ( + tuple(args_other.items()) + + tuple(kwargs_other.items())) if _key in cache: _f = cache[_key] else: diff --git a/neural_tangents/_src/empirical.py b/neural_tangents/_src/empirical.py index 6bc2ad03..bc19c9b1 100644 --- a/neural_tangents/_src/empirical.py +++ b/neural_tangents/_src/empirical.py @@ -15,86 +15,90 @@ """Compute empirical NNGP and NTK; approximate functions via Taylor series. All functions in this module are applicable to any JAX functions of proper -signatures (not only those from `nt.stax`). +signatures (not only those from :obj:`~neural_tangents.stax`). -NNGP and NTK are computed using `nt.empirical_nngp_fn`, `nt.empirical_ntk_fn`, -or `nt.empirical_kernel_fn` (for both). The kernels have a very specific output -shape convention that may be unexpected. Further, NTK has multiple -implementations that may perform differently depending on the task. Please read -individual functions' docstrings. +NNGP and NTK are computed using :obj:`~neural_tangents.empirical_nngp_fn`, +:obj:`~neural_tangents.empirical_ntk_fn`, or +:obj:`~neural_tangents.empirical_kernel_fn` (for both). The kernels have a very +specific output shape convention that may be unexpected. Further, NTK has +multiple implementations that may perform differently depending on the task. +Please read individual functions' docstrings. + +For details, please see "`Fast Finite Width Neural Tangent Kernel +`_". Example: - >>> from jax import random - >>> import neural_tangents as nt - >>> from neural_tangents import stax - >>> - >>> key1, key2, key3 = random.split(random.PRNGKey(1), 3) - >>> x_train = random.normal(key1, (20, 32, 32, 3)) - >>> y_train = random.uniform(key1, (20, 10)) - >>> x_test = random.normal(key2, (5, 32, 32, 3)) - >>> - >>> # A narrow CNN. - >>> init_fn, f, _ = stax.serial( - >>> stax.Conv(32, (3, 3)), - >>> stax.Relu(), - >>> stax.Conv(32, (3, 3)), - >>> stax.Relu(), - >>> stax.Conv(32, (3, 3)), - >>> stax.Flatten(), - >>> stax.Dense(10) - >>> ) - >>> - >>> _, params = init_fn(key3, x_train.shape) - >>> - >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the - >>> # network is iid along the batch axis, no BatchNorm. Use default - >>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`). - >>> kernel_fn = nt.empirical_kernel_fn( - >>> f, trace_axes=(-1,), vmap_axes=0, - >>> implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION) - >>> - >>> # (5, 20) np.ndarray test-train NNGP/NTK - >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params) - >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params) - >>> - >>> # Full kernel: not reducing over logits. Use structured derivatives - >>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for - >>> # typically faster computation and lower memory cost. - >>> kernel_fn = nt.empirical_kernel_fn( - >>> f, trace_axes=(), vmap_axes=0, - >>> implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES) - >>> - >>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple. - >>> k_test_train = kernel_fn(x_test, x_train, None, params) - >>> - >>> # A wide FCN with lots of parameters and many (`100`) outputs. - >>> init_fn, f, _ = stax.serial( - >>> stax.Flatten(), - >>> stax.Dense(1024), - >>> stax.Relu(), - >>> stax.Dense(1024), - >>> stax.Relu(), - >>> stax.Dense(100) - >>> ) - >>> - >>> _, params = init_fn(key3, x_train.shape) - >>> - >>> # Use ntk-vector products - >>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the - >>> # network has many parameters relative to the cost of forward pass, - >>> # large outputs. - >>> ntk_fn = nt.empirical_ntk_fn( + >>> from jax import random + >>> import neural_tangents as nt + >>> from neural_tangents import stax + >>> # + >>> key1, key2, key3 = random.split(random.PRNGKey(1), 3) + >>> x_train = random.normal(key1, (20, 32, 32, 3)) + >>> y_train = random.uniform(key1, (20, 10)) + >>> x_test = random.normal(key2, (5, 32, 32, 3)) + >>> # + >>> # A narrow CNN. + >>> init_fn, f, _ = stax.serial( + >>> stax.Conv(32, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(32, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(32, (3, 3)), + >>> stax.Flatten(), + >>> stax.Dense(10) + >>> ) + >>> # + >>> _, params = init_fn(key3, x_train.shape) + >>> # + >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the + >>> # network is iid along the batch axis, no BatchNorm. Use default + >>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`). + >>> kernel_fn = nt.empirical_kernel_fn( + >>> f, trace_axes=(-1,), vmap_axes=0, + >>> implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION) + >>> # + >>> # (5, 20) np.ndarray test-train NNGP/NTK + >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params) + >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params) + >>> # + >>> # Full kernel: not reducing over logits. Use structured derivatives + >>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for + >>> # typically faster computation and lower memory cost. + >>> kernel_fn = nt.empirical_kernel_fn( + >>> f, trace_axes=(), vmap_axes=0, + >>> implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES) + >>> # + >>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple. + >>> k_test_train = kernel_fn(x_test, x_train, None, params) + >>> # + >>> # A wide FCN with lots of parameters and many (`100`) outputs. + >>> init_fn, f, _ = stax.serial( + >>> stax.Flatten(), + >>> stax.Dense(1024), + >>> stax.Relu(), + >>> stax.Dense(1024), + >>> stax.Relu(), + >>> stax.Dense(100) + >>> ) + >>> # + >>> _, params = init_fn(key3, x_train.shape) + >>> # + >>> # Use ntk-vector products + >>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the + >>> # network has many parameters relative to the cost of forward pass, + >>> # large outputs. + >>> ntk_fn = nt.empirical_ntk_fn( >>> f, vmap_axes=0, >>> implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS) - >>> - >>> # (5, 5) np.ndarray test-test NTK - >>> ntk_test_test = ntk_fn(x_test, None, params) - >>> - >>> # Compute only output variances: - >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,)) - >>> - >>> # (20,) np.ndarray train-train diagonal NNGP - >>> nngp_train_train_diag = nngp_fn(x_train, None, params) + >>> # + >>> # (5, 5) np.ndarray test-test NTK + >>> ntk_test_test = ntk_fn(x_test, None, params) + >>> # + >>> # Compute only output variances: + >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,)) + >>> # + >>> # (20,) np.ndarray train-train diagonal NNGP + >>> nngp_train_train_diag = nngp_fn(x_train, None, params) """ import enum @@ -215,14 +219,15 @@ def empirical_nngp_fn( The Neural Network Gaussian Process (NNGP) kernel is defined as :math:`f(X_1) f(X_2)^T`, i.e. the outer product of the function outputs. - WARNING: resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` - subject to `trace_axes` and `diagonal_axes` parameters, which make certain - assumptions about the outputs `f(x)` that may only be true in the infinite - width / infinite number of samples limit, or may not apply to your - architecture. For most precise results in the context of linearized training - dynamics of a specific finite-width network, set both `trace_axes=()` and - `diagonal_axes=()` to obtain the kernel exactly of shape - `zip(f(x1).shape, f(x2).shape)`. + .. warning:: + Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal the empirical kernels will have terms measuring the covariance between the @@ -327,66 +332,64 @@ def contract(out1, out2): class NtkImplementation(enum.IntEnum): """Implementation method of the underlying finite width NTK computation. - Below is a very brief summary of each method. For details see - "Fast Finite Width Neural Tangent Kernel", ICML 2022. - - `AUTO` (`0`) evaluates FLOPs of all other methods at compilation time, and - selects the fastest method. However, at the time it only works correctly on - TPUs, and on CPU/GPU can return wrong results, which is why it is not the - default. TODO(romann): revisit based on http://b/202218145. - - `JACOBIAN_CONTRACTION` (`1`) computes the NTK as the outer product of two - Jacobians, each computed using reverse-mode Autodiff (vector-Jacobian - products, VJPs). When JITted, the contraction is performed in a layerwise - fashion, so that entire Jacobians aren't necessarily instantiated in memory - at once, and the memory usage of the method can be lower than memory needed to - instantiate the two Jacobians. This method is best suited for networks with - small outputs (such as scalar outputs for binary classification or regression, - as opposed to 1000 ImageNet classes), and an expensive forward pass relative - to the number of parameters (such as CNNs, where forward pass reuses a small - filter bank many times). It is also the the most reliable method, since its - implementation is simplest, and reverse-mode Autodiff is most commonly used - and well tested elsewhere. For this reason it is set as the default. - - `NTK_VECTOR_PRODUCTS` (`2`) computes the NTK as a sequence of NTK-vector - products, similarly to how a Jacobian is computed as a sequence of - Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs). This - amounts to using both forward (JVPs) and reverse (VJPs) mode Autodiff, and - allows to eliminate the Jacobian contraction at the expense of additional - forward passes. Therefore this method is recommended for networks with a cheap - forwards pass relative to the number of parameters (e.g. - fully-connected networks, where each parameter matrix is used only once in the - forward pass), and networks with large outputs (e.g. 1000 ImageNet classes). - Memory requirements of this method are same as `JACOBIAN_CONTRACTION` (`1`). - Due to reliance of forward-mode Autodiff, this method is slightly more prone - to JAX and XLA bugs than `JACOBIAN_CONTRACTION` (`1`), but overall is quite - simple and reliable. - - `STRUCTURED_DERIVATIVES` (`3`) uses a custom JAX interpreter to compute the - NTK more efficiently than other methods. It traverses the computational graph - of a function in the same order as during reverse-mode Autodiff, but instead - of computing VJPs, it directly computes MJJMPs, - "matrix-Jacobian-Jacobian-matrix" products, which arise in the computation of - an NTK. Each MJJMP computation relies on the structure in the Jacobians, hence - the name. This method can be dramatically faster (up to several orders of - magnitude) then other methods on - fully-connected networks, and is usually faster or equivalent on CNNs, - Transformers, and other architectures, but exact speedup (e.g. from no - speedup to 10X) depends on each specific setting. It can also use less memory - than other methods. In our experience it consistently outperforms other - methods in most settings. However, its implementation is significantly more - complex (hence bug-prone), and it doesn't yet support functions using more - exotic JAX primitives (e.g. `remat`, parallel collectives such as `psum`, - compiled loops, etc.), which is why it is highly-recommended to try, but not - set as the default yet. - - WARNING: since different implementations use different Autodiff modes, they - can return slightly different numerical values due to different order of - contractions. If out-of-bounds indexing happens anywhere in the computation, - results can be completely different - (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing). - Type conversion inside the computation can also lead to differently-typed - results. + Below is a very brief summary of each method. For details, please see "`Fast + Finite Width Neural Tangent Kernel `_". + + Attributes: + AUTO: + (or `0`) evaluates FLOPs of all other methods at compilation time, + and selects the fastest method. However, at the time it only works + correctly on TPUs, and on CPU/GPU can return wrong results, which is why + it is not the default. TODO(romann): revisit based on http://b/202218145. + + JACOBIAN_CONTRACTION: + (or `1`) computes the NTK as the outer product of two Jacobians, each + computed using reverse-mode Autodiff (vector-Jacobian products, VJPs). + When JITted, the contraction is performed in a layerwise fashion, so that + entire Jacobians aren't necessarily instantiated in memory at once, and + the memory usage of the method can be lower than memory needed to + instantiate the two Jacobians. This method is best suited for networks + with small outputs (such as scalar outputs for binary classification or + regression, as opposed to 1000 ImageNet classes), and an expensive + forward pass relative to the number of parameters (such as CNNs, where + forward pass reuses a small filter bank many times). It is also the the + most reliable method, since its implementation is simplest, and + reverse-mode Autodiff is most commonly used and well tested elsewhere. + For this reason it is set as the default. + + NTK_VECTOR_PRODUCTS: + (or `2`) computes the NTK as a sequence of NTK-vector products, similarly + to how a Jacobian is computed as a sequence of Jacobian-vector products + (JVPs) or vector-Jacobian products (VJPs). This amounts to using both + forward (JVPs) and reverse (VJPs) mode Autodiff, and allows to eliminate + the Jacobian contraction at the expense of additional forward passes. + Therefore this method is recommended for networks with a cheap forward + pass relative to the number of parameters (e.g. fully-connected networks, + where each parameter matrix is used only once in the forward pass), and + networks with large outputs (e.g. 1000 ImageNet classes). Memory + requirements of this method are same as :attr:`JACOBIAN_CONTRACTION` + (`1`). Due to reliance of forward-mode Autodiff, this method is slightly + more prone to JAX and XLA bugs than :attr:`JACOBIAN_CONTRACTION` (`1`), + but overall is quite simple and reliable. + + STRUCTURED_DERIVATIVES: + (or `3`) uses a custom JAX interpreter to compute the NTK more + efficiently than other methods. It traverses the computational graph of a + function in the same order as during reverse-mode Autodiff, but instead + of computing VJPs, it directly computes MJJMPs, + "matrix-Jacobian-Jacobian-matrix" products, which arise in the + computation of an NTK. Each MJJMP computation relies on the structure in + the Jacobians, hence the name. This method can be dramatically faster + (up to several orders of magnitude) then other methods on fully-connected + networks, and is usually faster or equivalent on CNNs, Transformers, and + other architectures, but exact speedup (e.g. from no speedup to 10X) + depends on each specific setting. It can also use less memory than other + methods. In our experience it consistently outperforms other methods in + most settings. However, its implementation is significantly more complex + (hence bug-prone), and it doesn't yet support functions using more exotic + JAX primitives (e.g. :obj:`jax.checkpoint`, parallel collectives such as + :obj:`jax.lax.psum`, compiled loops like :obj:`jax.lax.scan`, etc.), which + is why it is highly-recommended to try, but not set as the default yet. """ AUTO = 0 JACOBIAN_CONTRACTION = 1 @@ -902,14 +905,15 @@ def empirical_ntk_fn( 3) make sure to set `vmap_axes` correctly. 4) try different `implementation` values. - WARNING: Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` - subject to `trace_axes` and `diagonal_axes` parameters, which make certain - assumptions about the outputs `f(x)` that may only be true in the infinite - width / infinite number of samples limit, or may not apply to your - architecture. For most precise results in the context of linearized training - dynamics of a specific finite-width network, set both `trace_axes=()` and - `diagonal_axes=()` to obtain the kernel exactly of shape - `zip(f(x1).shape, f(x2).shape)`. + .. warning:: + Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal the empirical kernels will have terms measuring the covariance between the @@ -981,31 +985,34 @@ def empirical_ntk_fn( set to `None`, to avoid wrong (and potentially silent) results. implementation: - Applicable only to NTK, an `NtkImplementation` value (or an integer `0`, - `1`, `2`, or `3`). See the `NtkImplementation` enum docstring for details. + An :class:`NtkImplementation` value (or an :class:`int` `0`, `1`, `2`, + or `3`). See the :class:`NtkImplementation` docstring for details. _j_rules: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow custom Jacobian rules for intermediary primitive `dy/dw` - computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to - `False` to use JVPs or VJPs, via JAX's `jacfwd` or `jacrev`. Custom + Internal debugging parameter, applicable only when + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + custom Jacobian rules for intermediary primitive `dy/dw` computations for + MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use + JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom Jacobian rules (`True`) are expected to be not worse, and sometimes better than automated alternatives, but in case of a suboptimal implementation setting it to `False` could improve performance. _s_rules: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow efficient MJJMp rules for structured `dy/dw` primitive - Jacobians. In practice should be set to `True`, and setting it to `False` - can lead to dramatic deterioration of performance. + Internal debugging parameter, applicable only when + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + efficient MJJMp rules for structured `dy/dw` primitive Jacobians. In + practice should be set to `True`, and setting it to `False` can lead to + dramatic deterioration of performance. _fwd: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow `jvp` in intermediary primitive Jacobian `dy/dw` - computations, `False` to always use `vjp`. `None` to decide automatically + Internal debugging parameter, applicable only when + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations, + `False` to always use :obj:`jax.vjp`. `None` to decide automatically based on input/output sizes. Applicable when `_j_rules=False`, or when a primitive does not have a Jacobian rule. Should be set to `None` for best performance. @@ -1039,14 +1046,15 @@ def empirical_kernel_fn( ) -> EmpiricalGetKernelFn: r"""Returns a function that computes single draws from NNGP and NT kernels. - WARNING: resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` - subject to `trace_axes` and `diagonal_axes` parameters, which make certain - assumptions about the outputs `f(x)` that may only be true in the infinite - width / infinite number of samples limit, or may not apply to your - architecture. For most precise results in the context of linearized training - dynamics of a specific finite-width network, set both `trace_axes=()` and - `diagonal_axes=()` to obtain the kernel exactly of shape - `zip(f(x1).shape, f(x2).shape)`. + .. warning:: + Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal the empirical kernels will have terms measuring the covariance between the @@ -1121,31 +1129,35 @@ def empirical_kernel_fn( set to `None`, to avoid wrong (and potentially silent) results. implementation: - Applicable only to NTK, an `NtkImplementation` value (or an integer `0`, - `1`, `2`, or `3`). See the `NtkImplementation` enum docstring for details. + Applicable only to NTK, an :class:`NtkImplementation` value (or an + :class:`int` `0`, `1`, `2`, or `3`). See the :class:`NtkImplementation` + docstring for details. _j_rules: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow custom Jacobian rules for intermediary primitive `dy/dw` - computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to - `False` to use JVPs or VJPs, via JAX's `jacfwd` or `jacrev`. Custom + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + custom Jacobian rules for intermediary primitive `dy/dw` computations for + MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use + JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom Jacobian rules (`True`) are expected to be not worse, and sometimes better than automated alternatives, but in case of a suboptimal implementation setting it to `False` could improve performance. _s_rules: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow efficient MJJMp rules for structured `dy/dw` primitive - Jacobians. In practice should be set to `True`, and setting it to `False` - can lead to dramatic deterioration of performance. + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + efficient MJJMp rules for structured `dy/dw` primitive Jacobians. In + practice should be set to `True`, and setting it to `False` can lead to + dramatic deterioration of performance. _fwd: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow `jvp` in intermediary primitive Jacobian `dy/dw` - computations, `False` to always use `vjp`. `None` to decide automatically + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations, + `False` to always use :obj:`jax.vjp`. `None` to decide automatically based on input/output sizes. Applicable when `_j_rules=False`, or when a primitive does not have a Jacobian rule. Should be set to `None` for best performance. @@ -1410,10 +1422,10 @@ def expand(x: np.ndarray) -> np.ndarray: def _expand_dims( - x: Optional[PyTree], + x: Union[Optional[PyTree], UndefinedPrimal], axis: Optional[PyTree] ) -> Optional[PyTree]: - if axis is None or x is None: + if axis is None or x is None or isinstance(x, UndefinedPrimal): return x return tree_map(_expand_dims_array, x, axis) @@ -1757,7 +1769,9 @@ def read_cotangent(v: Var) -> Union[np.ndarray, Zero]: if not isinstance(cts_in, ShapedArray): raise TypeError(cts_in) trimmed_cts_in = _trim_cotangents(cts_in, structure) - eqn = _trim_eqn(eqn, i_eqn, trimmed_invals, trimmed_cts_in) + + if _s_rules: + eqn = _trim_eqn(eqn, i_eqn, trimmed_invals, trimmed_cts_in) def j_fn(invals): return _get_jacobian(eqn=eqn, diff --git a/neural_tangents/_src/monte_carlo.py b/neural_tangents/_src/monte_carlo.py index 220809a8..ea61c7c4 100644 --- a/neural_tangents/_src/monte_carlo.py +++ b/neural_tangents/_src/monte_carlo.py @@ -20,9 +20,10 @@ Note that the `monte_carlo_kernel_fn` accepts arguments like `batch_size`, `device_count`, and `store_on_device`, and is appropriately batched / -parallelized. You don't need to apply the `nt.batch` or `jax.jit` decorators to -it. Further, you do not need to apply `jax.jit` to the input `apply_fn` -function, as the resulting empirical kernel function is JITted internally. +parallelized. You don't need to apply the :obj:`~neural_tangents.batch` or +:obj:`jax.jit` decorators to it. Further, you do not need to apply +:obj:`jax.jit` to the input `apply_fn` function, as the resulting empirical +kernel function is JITted internally. """ @@ -139,13 +140,13 @@ def monte_carlo_kernel_fn( Args: init_fn: a function initializing parameters of the neural network. From - `jax.example_libraries.stax`: "takes an rng key and an input shape and - returns an `(output_shape, params)` pair". + :obj:`jax.example_libraries.stax`: "takes an rng key and an input shape + and returns an `(output_shape, params)` pair". apply_fn: a function computing the output of the neural network. - From `jax.example_libraries.stax`: "takes params, inputs, and an rng key - and applies the layer". + From :obj:`jax.example_libraries.stax`: "takes params, inputs, and an + rng key and applies the layer". key: RNG (`jax.random.PRNGKey`) for sampling random networks. Must have @@ -218,32 +219,35 @@ def monte_carlo_kernel_fn( set to `None`, to avoid wrong (and potentially silent) results. implementation: - Applicable only to NTK, an `NtkImplementation` value (or an integer `0`, - `1`, `2`, or `3`). See the `neural_tangents.NtkImplementation` enum + Applicable only to NTK, an :class:`NtkImplementation` value (or an + :class:`int` `0`, `1`, `2`, or `3`). See the :class:`NtkImplementation` docstring for details. _j_rules: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow custom Jacobian rules for intermediary primitive `dy/dw` - computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to - `False` to use JVPs or VJPs, via JAX's `jacfwd` or `jacrev`. Custom + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + custom Jacobian rules for intermediary primitive `dy/dw` computations for + MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use + JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom Jacobian rules (`True`) are expected to be not worse, and sometimes better than automated alternatives, but in case of a suboptimal implementation setting it to `False` could improve performance. _s_rules: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow efficient MJJMp rules for structured `dy/dw` primitive - Jacobians. In practice should be set to `True`, and setting it to `False` - can lead to dramatic deterioration of performance. + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + efficient MJJMp rules for structured `dy/dw` primitive Jacobians. In + practice should be set to `True`, and setting it to `False` can lead to + dramatic deterioration of performance. _fwd: Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow `jvp` in intermediary primitive Jacobian `dy/dw` - computations, `False` to always use `vjp`. `None` to decide automatically + `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow + :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations, + `False` to always use :obj:`jax.vjp`. `None` to decide automatically based on input/output sizes. Applicable when `_j_rules=False`, or when a primitive does not have a Jacobian rule. Should be set to `None` for best performance. @@ -259,12 +263,12 @@ def monte_carlo_kernel_fn( >>> from jax import random >>> import neural_tangents as nt >>> from neural_tangents import stax - >>> + >>> # >>> key1, key2 = random.split(random.PRNGKey(1), 2) >>> x_train = random.normal(key1, (20, 32, 32, 3)) >>> y_train = random.uniform(key1, (20, 10)) >>> x_test = random.normal(key2, (5, 32, 32, 3)) - >>> + >>> # >>> init_fn, apply_fn, _ = stax.serial( >>> stax.Conv(128, (3, 3)), >>> stax.Relu(), @@ -274,12 +278,12 @@ def monte_carlo_kernel_fn( >>> stax.Flatten(), >>> stax.Dense(10) >>> ) - >>> + >>> # >>> n_samples = 200 >>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples) >>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk')) >>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`. - >>> + >>> # >>> n_samples = [1, 10, 100, 1000] >>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, >>> n_samples) diff --git a/neural_tangents/_src/predict.py b/neural_tangents/_src/predict.py index d5512563..a1ce991c 100644 --- a/neural_tangents/_src/predict.py +++ b/neural_tangents/_src/predict.py @@ -18,12 +18,13 @@ function `predict_fn` that computes predictions on the train set / given test set / timesteps. -WARNING: `trace_axes` parameter supplied to prediction functions must match the -respective parameter supplied to the function used to compute the kernel. -Namely, this is the same `trace_axes` used to compute the empirical kernel -(`utils/empirical.py`; `diagonal_axes` must be `()`), or `channel_axis` in the -output of the top layer used to compute the closed-form kernel (`stax.py`; note -that closed-form kernels currently only support a single `channel_axis`). +.. warning:: + `trace_axes` parameter supplied to prediction functions must match the + respective parameter supplied to the function used to compute the kernel. + Namely, this is the same `trace_axes` used to compute the empirical kernel + (`utils/empirical.py`; `diagonal_axes` must be `()`), or `channel_axis` in the + output of the top layer used to compute the closed-form kernel (`stax.py`; + note that closed-form kernels currently only support a single `channel_axis`). """ @@ -44,8 +45,8 @@ from .utils.typing import Axes, Get, KernelFn -"""Alias for optional arrays or scalars.""" ArrayOrScalar = Union[None, int, float, np.ndarray] +"""Alias for optional arrays or scalars.""" class PredictFn(Protocol): @@ -88,22 +89,25 @@ def gradient_descent_mse( `k_train_train` is performed and cached for future invocations (or both, if the function is called on both finite and infinite (`t=None`) times). - [*] https://arxiv.org/abs/1806.07572 - [**] https://arxiv.org/abs/1902.06720 + [*] "`Neural Tangent Kernel: Convergence and Generalization in Neural Networks + `_" + + [**] "`Wide Neural Networks of Any Depth Evolve as Linear + Models Under Gradient Descent `_" Example: >>> import neural_tangents as nt - >>> + >>> # >>> t = 1e-7 >>> kernel_fn = nt.empirical_ntk_fn(f) >>> k_train_train = kernel_fn(x_train, None, params) >>> k_test_train = kernel_fn(x_test, x_train, params) - >>> + >>> # >>> predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train) - >>> + >>> # >>> fx_train_0 = f(params, x_train) >>> fx_test_0 = f(params, x_test) - >>> + >>> # >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, >>> k_test_train) @@ -111,17 +115,22 @@ def gradient_descent_mse( k_train_train: kernel on the training data. Must have the shape of `zip(y_train.shape, y_train.shape)` with `trace_axes` absent. + y_train: targets for the training data. + learning_rate: learning rate, step size. + diag_reg: a scalar representing the strength of the diagonal regularization for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during Cholesky factorization or eigendecomposition. + diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization in absolute units, `False` to be `diag_reg * np.mean(np.trace(k_train_train))`. + trace_axes: `f(x_train)` axes such that `k_train_train` lacks these pairs of dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e. @@ -228,12 +237,15 @@ def predict_fn( using identity or linear solve for train and test predictions respectively instead of eigendecomposition, saving time and precision. Equivalent of training steps (but can be fractional). + fx_train_0: output of the network at `t == 0` on the training set. `fx_train_0=None` means to not compute predictions on the training set. + fx_test_0: output of the network at `t == 0` on the test set. `fx_test_0=None` means to not compute predictions on the test set. + k_test_train: kernel relating test data with training data. Must have the shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass @@ -262,7 +274,21 @@ def predict_fn( @dataclasses.dataclass class ODEState: - """ODE state dataclass holding outputs and auxiliary variables.""" + """ODE state dataclass holding outputs and auxiliary variables. + + Attributes: + fx_train: + training set outputs. + + fx_test: + test set outputs. + + qx_train: + training set auxiliary state variable (e.g. momentum). + + qx_test: + test set auxiliary state variable (e.g. momentum). + """ fx_train: Optional[np.ndarray] = None fx_test: Optional[np.ndarray] = None qx_train: Optional[np.ndarray] = None @@ -293,38 +319,40 @@ def gradient_descent( r"""Predicts the outcome of function space training using gradient descent. Uses an ODE solver. If `momentum != None`, solves a continuous-time version of - gradient descent with momentum (note: this case uses standard momentum as - opposed to Nesterov momentum). + gradient descent with momentum. - Solves the function space ODE for [momentum] gradient descent with a given - `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s] - at arbitrary time[s] (step[s]) `t`. Note that for gradient descent - `absolute_time = learning_rate * t` and the scales of the learning rate and - query step[s] `t` are interchangeable. However, the momentum gradient descent - ODE is solved in the units of `learning_rate**0.5`, and therefore - `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and - training time[s] (step[s]) `t` scales are not interchangeable. + .. note:: + We use standard momentum as opposed to Nesterov momentum. - [*] https://arxiv.org/abs/1902.06720 + Solves the function space ODE for [momentum] gradient descent with a given + `loss` (detailed in "`Wide Neural Networks of Any Depth Evolve as Linear + Models Under Gradient Descent `_".) given a + Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) + `t`. Note that for gradient descent `absolute_time = learning_rate * t` and + the scales of the learning rate and query step[s] `t` are interchangeable. + However, the momentum gradient descent ODE is solved in the units of + `learning_rate**0.5`, and therefore `absolute_time = learning_rate**0.5 * t`, + hence the `learning_rate` and training time[s] (step[s]) `t` scales are not + interchangeable. Example: >>> import neural_tangents as nt - >>> + >>> # >>> t = 1e-7 >>> learning_rate = 1e-2 >>> momentum = 0.9 - >>> + >>> # >>> kernel_fn = nt.empirical_ntk_fn(f) >>> k_test_train = kernel_fn(x_test, x_train, params) - >>> + >>> # >>> from jax.nn import log_softmax >>> cross_entropy = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat) >>> predict_fn = nt.redict.gradient_descent( >>> cross_entropy, k_train_train, y_train, learning_rate, momentum) - >>> + >>> # >>> fx_train_0 = f(params, x_train) >>> fx_test_0 = f(params, x_test) - >>> + >>> # >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, >>> k_test_train) @@ -333,15 +361,20 @@ def gradient_descent( a loss function whose signature is `loss(f(x_train), y_train)`. Note: the loss function should treat the batch and output dimensions symmetrically. + k_train_train: kernel on the training data. Must have the shape of `zip(y_train.shape, y_train.shape)` with `trace_axes` absent. + y_train: targets for the training data. + learning_rate: learning rate, step size. + momentum: momentum scalar. + trace_axes: `f(x_train)` axes such that `k_train_train` lacks these pairs of dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e. @@ -440,6 +473,7 @@ def predict_fn( a scalar or array of scalars of any shape in strictly increasing order. `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of training steps (but can be fractional). + fx_train_or_state_0: either (a) output of the network at `t == 0` on the training set or (b) complete ODE state (`predict.ODEState`). Pass an ODE state if you want @@ -451,9 +485,11 @@ def predict_fn( `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an ODE state is returned. `fx_train_0=None` means to not compute predictions on the training set. + fx_test_0: output of the network at `t == 0` on the test set. `fx_test_0=None` means to not compute predictions on the test set. + k_test_train: kernel relating test data with training data. Must have the shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass @@ -508,7 +544,15 @@ def predict_fn( class Gaussian(NamedTuple): - """A `(mean, covariance)` convenience namedtuple.""" + """A `(mean, covariance)` convenience namedtuple. + + Attributes: + mean: + Mean of shape equal to the shape of the function outputs. + + covariance: + Covariance of shape equal to the shape of the respective NTK/NNGP kernel. + """ mean: np.ndarray covariance: np.ndarray @@ -524,8 +568,10 @@ def gp_inference( NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time. NTKGP - posterior of a GP (Gaussian process) - with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this - can correspond to infinite ensembles of infinitely wide NNs as well). + with the NTK covariance (see + "`Bayesian Deep Ensembles via the Neural Tangent Kernel + `_" for how this can correspond to infinite + ensembles of infinitely wide NNs as well). Note that first invocation of the returned `predict_fn` will be slow and allocate a lot of memory for its whole lifetime, as a Cholesky factorization @@ -534,21 +580,26 @@ def gp_inference( Args: k_train_train: - train-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c) - `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels - for arguments provided to the returned `predict_fn` function. For - example, if you request to compute posterior test [only] NTK covariance in - future `predict_fn` invocations, `k_train_train` must contain both `ntk` - and `nngp` kernels. + train-train kernel. Can be (a) :class:`jax.numpy.ndarray`, + (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object. + Must contain the necessary `nngp` and/or `ntk` kernels for arguments + provided to the returned `predict_fn` function. For example, if you + request to compute posterior test [only] NTK covariance in future + `predict_fn` invocations, `k_train_train` must contain both `ntk` and + `nngp` kernels. + y_train: train targets. + diag_reg: a scalar representing the strength of the diagonal regularization for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during Cholesky factorization. + diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization in absolute units, `False` to be `diag_reg * np.mean(np.trace(k_train_train))`. + trace_axes: `f(x_train)` axes such that `k_train_train`, `k_test_train`[, and `k_test_test`] lack these pairs of dimensions and @@ -587,33 +638,38 @@ def predict_fn(get: Optional[Get] = None, Args: get: string, the mode of the Gaussian process, either "nngp", "ntk", "ntkgp", - (see https://arxiv.org/abs/2007.05864) or a tuple, or `None`. If `None` + (see "`Bayesian Deep Ensembles via the Neural Tangent Kernel + `_") or a tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are returned. + k_test_train: - test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c) - `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels - for arguments provided to the returned `predict_fn` function. For - example, if you request to compute posterior test [only] NTK covariance, - `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`, - returns predictions on the training set. Note that train-set outputs are - always `N(y_train, 0)` and mostly returned for API consistency. + test-train kernel. Can be (a) :class:`jax.numpy.ndarray`, + (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object. + Must contain the necessary `nngp` and/or `ntk` kernels for arguments + provided to the returned `predict_fn` function. For example, if you + request to compute posterior test [only] NTK covariance, `k_test_train` + must contain both `ntk` and `nngp` kernels. If `None`, returns + predictions on the training set. Note that train-set outputs are always + `N(y_train, 0)` and mostly returned for API consistency. + k_test_test: - test-test kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c) - `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels - for arguments provided to the returned `predict_fn` function. Provide - if you want to compute test-test posterior covariance. - `k_test_test=None` means to not compute it. If `k_test_train is None`, - pass any non-`None` value (e.g. `True`) if you want to get - non-regularized (`diag_reg=0`) train-train posterior covariance. Note - that non-regularized train-set outputs will always be the zero-variance - Gaussian `N(y_train, 0)` and mostly returned for API consistency. For - regularized train-set posterior outputs according to a positive - `diag_reg`, pass `k_test_train=k_train_train`, and, optionally, + test-test kernel. Can be (a) :class:`jax.numpy.ndarray`, + (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object. + Must contain the necessary `nngp` and/or `ntk` kernels for arguments + provided to the returned `predict_fn` function. Provide if you want to + compute test-test posterior covariance. `k_test_test=None` means to not + compute it. If `k_test_train is None`, pass any non-`None` value (e.g. + `True`) if you want to get non-regularized (`diag_reg=0`) train-train + posterior covariance. Note that non-regularized train-set outputs will + always be the zero-variance Gaussian `N(y_train, 0)` and mostly + returned for API consistency. For regularized train-set posterior + outputs according to a positive `diag_reg`, pass + `k_test_train=k_train_train`, and, optionally, `k_test_test=nngp_train_train`. Returns: - Either a `Gaussian('mean', 'variance')` namedtuple or `mean` of the GP - posterior on the `test` set. + Either a :class:`Gaussian` `(mean, variance)` namedtuple or `mean` of the + GP posterior on the `test` set. """ if get is None: get = ('nngp', 'ntk') @@ -715,14 +771,14 @@ def gradient_descent_mse_ensemble( Args: kernel_fn: A kernel function that computes NNGP and/or NTK. Must have a signature - `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object - or a `namedtuple` with `nngp` and/or `ntk` attributes. Therefore, it can - be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn`, or an - `EmpiricalKernelFn` (but only `nt.empirical_kernel_fn` and not - `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since the latter - two do not accept a `get` argument). Note that for meaningful outputs, - the kernel function must represent or at least approximate the - infinite-width kernel. + `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a + :class:`~neural_tangents.Kernel` object or a `namedtuple` with `nngp` + and/or `ntk` attributes. Therefore, it can be an `AnalyticKernelFn`, but + also a `MonteCarloKernelFn`, or an `EmpiricalKernelFn` (but only + `nt.empirical_kernel_fn` and not `nt.empirical_ntk_fn` or + `nt.empirical_nngp_fn`, since the latter two do not accept a `get` + argument). Note that for meaningful outputs, the kernel function must + represent or at least approximate the infinite-width kernel. x_train: training inputs. @@ -891,16 +947,20 @@ def predict_fn(t: Optional[ArrayOrScalar] = None, infinity and returns the same result as `t=np.inf`, but is computed using linear solve for test predictions instead of eigendecomposition, saving time and precision. + x_test: test inputs. `None` means to return non-regularized (`diag_reg=0`) predictions on the train-set inputs. For regularized predictions, pass `x_test=x_train`. + get: string, the mode of the Gaussian process, either "nngp" or "ntk", or a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`. + compute_cov: if `True` computing both `mean` and `variance` and only `mean` otherwise. + **kernel_fn_test_test_kwargs: optional keyword arguments passed to `kernel_fn`. See also `kernel_fn_train_train_kwargs` argument of the parent function. @@ -1045,17 +1105,21 @@ def max_learning_rate( contraction, which is `2 * batch_size * output_size * lambda_max(NTK)`. When `momentum > 0`, we use `2 * (1 + momentum) * batch_size * output_size * lambda_max(NTK)` (see - `The Dynamics of Momentum` section in https://distill.pub/2017/momentum/). + *The Dynamics of Momentum* section in + "`Why Momentum Really Works `_"). Args: ntk_train_train: analytic or empirical NTK on the training data. + y_train_size: total training set output size, i.e. `f(x_train).size == y_train.size`. If `output_size=None` it is inferred from `ntk_train_train.shape` assuming `trace_axes=()`. + momentum: The `momentum` for momentum optimizers. + eps: a float to avoid zero divisor. @@ -1113,10 +1177,13 @@ def _get_fns_in_eigenbasis( Args: k_train_train: an n x n matrix. + diag_reg: diagonal regularizer strength. + diag_reg_absolute_scale: `True` to use absolute (vs relative to mean trace) regulatization. + fns: a sequence of functions that add on the eigenvalues (evals, dt) -> modified_evals. diff --git a/neural_tangents/_src/stax/branching.py b/neural_tangents/_src/stax/branching.py index cb73296b..6e0c207a 100644 --- a/neural_tangents/_src/stax/branching.py +++ b/neural_tangents/_src/stax/branching.py @@ -13,6 +13,7 @@ # limitations under the License. """Branching functions. + These layers split an input into multiple branches or fuse multiple inputs from several branches into one. """ @@ -31,7 +32,7 @@ @layer def FanOut(num: int) -> InternalLayer: - """Layer construction function for a fan-out layer. + """Fan-out. This layer takes an input and produces `num` copies that can be fed into different branches of a neural network (for example with residual @@ -51,7 +52,7 @@ def FanOut(num: int) -> InternalLayer: @layer @supports_masking(remask_kernel=False) def FanInSum() -> InternalLayerMasked: - """Layer construction function for a fan-in sum layer. + """Fan-in sum. This layer takes a number of inputs (e.g. produced by `FanOut`) and sums the inputs to produce a single output. @@ -113,7 +114,7 @@ def mask_fn(mask, input_shape): @layer @supports_masking(remask_kernel=False) def FanInProd() -> InternalLayerMasked: - """Layer construction function for a fan-in product layer. + """Fan-in product. This layer takes a number of inputs (e.g. produced by `FanOut`) and elementwise-multiplies the inputs to produce a single output. @@ -179,9 +180,9 @@ def mask_fn(mask, input_shape): @layer @supports_masking(remask_kernel=False) def FanInConcat(axis: int = -1) -> InternalLayerMasked: - """Layer construction function for a fan-in concatenation layer. + """Fan-in concatenation. - Based on `jax.example_libraries.stax.FanInConcat`. + Based on :obj:`jax.example_libraries.stax.FanInConcat`. Args: axis: Specifies the axis along which input tensors should be concatenated. diff --git a/neural_tangents/_src/stax/combinators.py b/neural_tangents/_src/stax/combinators.py index 38e73b89..9be4cbdc 100644 --- a/neural_tangents/_src/stax/combinators.py +++ b/neural_tangents/_src/stax/combinators.py @@ -15,21 +15,22 @@ """Layer combinators.""" import operator as op -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List import warnings -from jax import random + import frozendict +from jax import random import jax.example_libraries.stax as ostax from .requirements import Diagonal, get_req, layer, requires from ..utils.kernel import Kernel -from ..utils.typing import InternalLayer, Layer, LayerKernelFn, NTTree, NTTrees, Shapes, PyTree +from ..utils.typing import InternalLayer, Layer, LayerKernelFn, NTTree, NTTrees, Shapes @layer def serial(*layers: Layer) -> InternalLayer: """Combinator for composing layers in serial. - Based on `jax.example_libraries.stax.serial`. + Based on :obj:`jax.example_libraries.stax.serial`. Args: *layers: @@ -57,9 +58,10 @@ def kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]: def parallel(*layers: Layer) -> InternalLayer: """Combinator for composing layers in parallel. - The layer resulting from this combinator is often used with the `FanOut` and - `FanInSum`/`FanInConcat` layers. Based on - `jax.example_libraries.stax.parallel`. + The layer resulting from this combinator is often used with the + :obj:`~neural_tangents.stax.FanOut`, :obj:`~neural_tangents.stax.FanInSum`, + and :obj:`~neural_tangents.stax.FanInConcat` layers. Based on + :obj:`jax.example_libraries.stax.parallel`. Args: *layers: diff --git a/neural_tangents/_src/stax/elementwise.py b/neural_tangents/_src/stax/elementwise.py index cb99ab1a..02ea1af8 100644 --- a/neural_tangents/_src/stax/elementwise.py +++ b/neural_tangents/_src/stax/elementwise.py @@ -122,8 +122,8 @@ def Gelu( Args: approximate: only relevant for finite-width network, `apply_fn`. If `True`, computes - an approximation via `tanh`, see https://arxiv.org/abs/1606.08415 and - `jax.nn.gelu` for details. + an approximation via `tanh`, see "`Gaussian Error Linear Units (GELUs) + `_" and :obj:`jax.nn.gelu` for details. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -132,7 +132,11 @@ def fn(x): return jax.nn.gelu(x, approximate=approximate) def kernel_fn(k: Kernel) -> Kernel: - """Compute kernels after a `Gelu` layer; NNGP see `arXiv:2002.08517`.""" + """Compute kernels after a `Gelu` layer. + + For NNGP see "`Avoiding Kernel Fixed Points: Computing with ELU and GELU + Infinite Networks `_". + """ cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk cov1_plus_1 = cov1 + 1 @@ -282,7 +286,6 @@ def Rbf( def fn(x): return np.sqrt(2) * np.sin(np.sqrt(2 * gamma) * x + np.pi/4) - def kernel_fn(k: Kernel) -> Kernel: """Compute new kernels after an `Rbf` layer.""" cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk @@ -340,7 +343,8 @@ def fn(x): def kernel_fn(k: Kernel) -> Kernel: """Compute new kernels after an `ABRelu` layer. - See https://arxiv.org/pdf/1711.09090.pdf for the leaky ReLU derivation. + See "`Invariance of Weight Distributions in Rectified MLPs + `_" for the leaky ReLU derivation. """ cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk @@ -575,7 +579,8 @@ def ExpNormalized( do_clip: bool = False) -> InternalLayer: """Simulates the "Gaussian normalized kernel". - Source: https://arxiv.org/abs/2003.02237.pdf, page 6. + See page 6 in + "`Neural Kernels Without Tangents `_". Args: gamma: exponent scalar coefficient. @@ -693,18 +698,18 @@ def Elementwise( Example: >>> fn = jax.scipy.special.erf # type: Callable[[float], float] - >>> + >>> # >>> def nngp_fn(cov12: float, var1: float, var2: float) -> float: >>> prod = (1 + 2 * var1) * (1 + 2 * var2) >>> return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi - >>> + >>> # >>> # Use autodiff and vectorization to construct the layer: >>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn) - >>> + >>> # >>> # Use custom pre-derived expressions >>> # (should be faster and more numerically stable): >>> _, _, kernel_fn_stax = stax.Erf() - >>> + >>> # >>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2) # usually `True`. Args: @@ -748,9 +753,10 @@ def Elementwise( else: if d_nngp_fn is None: + url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where' warnings.warn( - 'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of ' - 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.') + f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware ' + f'of {url}.') d_nngp_fn = np.vectorize(grad(nngp_fn)) def kernel_fn(k: Kernel) -> Kernel: @@ -809,9 +815,10 @@ def ElementwiseNumerical( quad_points = osp.special.roots_hermite(deg) if df is None: + url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where' warnings.warn( - 'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of ' - 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.') + f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of ' + f'{url}.') df = np.vectorize(grad(fn)) def kernel_fn(k: Kernel) -> Kernel: diff --git a/neural_tangents/_src/stax/linear.py b/neural_tangents/_src/stax/linear.py index d93582ff..b569f2b1 100644 --- a/neural_tangents/_src/stax/linear.py +++ b/neural_tangents/_src/stax/linear.py @@ -14,7 +14,6 @@ """Linear functions.""" - import enum import functools import operator as op @@ -40,20 +39,51 @@ class Padding(enum.Enum): - """Type of padding in pooling and convolutional layers.""" + """Type of padding in pooling and convolutional layers. + + Attributes: + CIRCULAR: + circular padding, as if the input were a torus. + + SAME: + same, a.k.a. zero padding. + + VALID: + valid, a.k.a. no padding. + """ CIRCULAR = 'CIRCULAR' SAME = 'SAME' VALID = 'VALID' class _Pooling(enum.Enum): - """Type of pooling in pooling layers.""" + """Type of pooling in pooling layers. + + Attributes: + AVG: + average pooling, the output is normalized by the input receptive field + size. + + SUM: + sum pooling, no normalization. + """ AVG = 'AVG' SUM = 'SUM' class AggregateImplementation(enum.Enum): - """Implementation of the `Aggregate` layer.""" + """Implementation of the :obj:`Aggregate` layer. + + See :obj:`Aggregate` docstring for details. + + Attributes: + DENSE: + Is recommended for dense graphs, where the number of edges `E` is + proportional to the number of vertices `V` to the power of 1.5 or more. + + SPARSE: + Is recommended for sparse graphs, where `E ~ O(V)` or less. + """ DENSE = 'DENSE' SPARSE = 'SPARSE' @@ -64,9 +94,9 @@ class AggregateImplementation(enum.Enum): @layer @supports_masking(remask_kernel=False) def Identity() -> InternalLayer: - """Layer construction function for an identity layer. + """Identity (no-op). - Based on `jax.example_libraries.stax.Identity`. + Based on :obj:`jax.example_libraries.stax.Identity`. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -87,7 +117,7 @@ def DotGeneral( batch_axis: int = 0, channel_axis: int = -1 ) -> InternalLayerMasked: - r"""Layer constructor for a constant (non-trainable) rhs/lhs Dot General. + r"""Constant (non-trainable) rhs/lhs Dot General. Dot General allows to express any linear transformation on the inputs, including but not limited to matrix multiplication, pooling, convolutions, @@ -100,40 +130,40 @@ def DotGeneral( on whether `lhs` or `rhs` is specified (not `None`). Example: - >>> from jax import random - >>> import jax.numpy as np - >>> from neural_tangents import stax - >>> - >>> # Two time series stacked along the second (H) dimension. - >>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3)) # NHWC - >>> - >>> # Multiply all outputs by a scalar: - >>> nn = stax.serial( - >>> stax.Conv(128, (1, 3)), - >>> stax.Relu(), - >>> stax.DotGeneral(rhs=2.), # output shape is (5, 2, 30, 128) - >>> stax.GlobalAvgPool() # (5, 128) - >>> ) - >>> - >>> # Subtract second time series from the first one: - >>> nn = stax.serial( - >>> stax.Conv(128, (1, 3)), - >>> stax.Relu(), - >>> stax.DotGeneral( - >>> rhs=np.array([1., -1.]), - >>> dimension_numbers=(((1,), (0,)), ((), ()))), # (5, 30, 128) - >>> stax.GlobalAvgPool() # (5, 128) - >>> ) - >>> - >>> # Flip outputs with each other - >>> nn = stax.serial( - >>> stax.Conv(128, (1, 3)), - >>> stax.Relu(), - >>> stax.DotGeneral( - >>> lhs=np.array([[0., 1.], [1., 0.]]), - >>> dimension_numbers=(((1,), (1,)), ((), ()))), # (5, 2, 30, 128) - >>> stax.GlobalAvgPool() # (5, 128) - >>> ) + >>> from jax import random + >>> import jax.numpy as np + >>> from neural_tangents import stax + >>> # + >>> # Two time series stacked along the second (H) dimension. + >>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3)) # NHWC + >>> # + >>> # Multiply all outputs by a scalar: + >>> nn = stax.serial( + >>> stax.Conv(128, (1, 3)), + >>> stax.Relu(), + >>> stax.DotGeneral(rhs=2.), # output shape is (5, 2, 30, 128) + >>> stax.GlobalAvgPool() # (5, 128) + >>> ) + >>> # + >>> # Subtract second time series from the first one: + >>> nn = stax.serial( + >>> stax.Conv(128, (1, 3)), + >>> stax.Relu(), + >>> stax.DotGeneral( + >>> rhs=np.array([1., -1.]), + >>> dimension_numbers=(((1,), (0,)), ((), ()))), # (5, 30, 128) + >>> stax.GlobalAvgPool() # (5, 128) + >>> ) + >>> # + >>> # Flip outputs with each other + >>> nn = stax.serial( + >>> stax.Conv(128, (1, 3)), + >>> stax.Relu(), + >>> stax.DotGeneral( + >>> lhs=np.array([[0., 1.], [1., 0.]]), + >>> dimension_numbers=(((1,), (1,)), ((), ()))), # (5, 2, 30, 128) + >>> stax.GlobalAvgPool() # (5, 128) + >>> ) See Also: https://www.tensorflow.org/xla/operation_semantics#dotgeneral @@ -221,9 +251,11 @@ def Aggregate( to_dense: Optional[Callable[[np.ndarray], np.ndarray]] = lambda p: p, implementation: str = AggregateImplementation.DENSE.value ) -> InternalLayer: - r"""Layer constructor for aggregation operator (graphical neural network). + r"""Aggregation operator (graphical neural network). - See e.g. https://arxiv.org/abs/1905.13192. + See e.g. + "`Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels + `_". Specifically, each `N+2`-D `input` of shape `(batch, X_1, ..., X_N, channels)` (subject to `batch_axis` and `channel_axis`) is accompanied by an array @@ -289,89 +321,87 @@ def Aggregate( sparse and dense patterns. Example: - >>> # 1D inputs - >>> x = random.normal(random.PRNGKey(1), (5, 3, 32)) # NCH - >>> - >>> # 1) NHH dense binary adjacency matrix - >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32)) - >>> # `A[n, h1, h2] == True` - >>> # means an edge between tokens `h1` and `h2` in sample `n`. - >>> - >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2, - >>> batch_axis=0, - >>> channel_axis=1) - >>> - >>> out = apply_fn((), x, pattern=A) - >>> # output is the same as `x @ A` of shape (5, 3, 32) - >>> - >>> # Sparse NHH binary pattern with 10 edges - >>> n_edges = 10 - >>> A_sparse = random.randint(random.PRNGKey(3), - >>> shape=(x.shape[0], n_edges, 1, 2), - >>> minval=0, - >>> maxval=x.shape[2]) - >>> - >>> # Setting `implementation="SPARSE"` to invoke the segment sum - >>> # implementation. - >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2, - >>> batch_axis=0, - >>> channel_axis=1, - >>> implementation="SPARSE") - >>> - >>> out = apply_fn((), x, pattern=A_sparse) - >>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`. - >>> - >>> # 2D inputs - >>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16)) # NCHW - >>> - >>> # 2) NHWHW dense binary adjacency matrix - >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16)) - >>> # `A[n, h1, w1, h2, w2] == True` - >>> # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`. - >>> - >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3), - >>> batch_axis=0, - >>> channel_axis=1) - >>> - >>> out = apply_fn((), x, pattern=A) - >>> # output is of shape (5, 3, 32, 16), the same as - >>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16)) - >>> # ).reshape(x.shape)` - >>> - >>> - >>> # 3) NWW binary adjacency matrix - >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16)) - >>> # `A[n, w1, w2] == True` - >>> # means an edge between rows `w1` and `w2` in image `n`. - >>> - >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,), - >>> batch_axis=0, - >>> channel_axis=1) - >>> - >>> out = apply_fn((), x, pattern=A) - >>> # output is of shape (5, 3, 32, 16), the same as - >>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)` - >>> - >>> - >>> # 4) Infinite width example - >>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32)) # NCH - >>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32)) # NCH - >>> - >>> # NHH binary adjacency matrices - >>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32)) - >>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32)) - >>> - >>> _, _, kernel_fn_id = stax.Identity() - >>> - >>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2, - >>> batch_axis=0, - >>> channel_axis=1) - >>> - >>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1) - >>> # initial NNGP of shape (5, 2, 32, 32) - >>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2)) - >>> # output NNGP of same shape (5, 2, 32, 32): - >>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]` + >>> # 1D inputs + >>> x = random.normal(random.PRNGKey(1), (5, 3, 32)) # NCH + >>> # + >>> # 1) NHH dense binary adjacency matrix + >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32)) + >>> # `A[n, h1, h2] == True` + >>> # means an edge between tokens `h1` and `h2` in sample `n`. + >>> # + >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2, + >>> batch_axis=0, + >>> channel_axis=1) + >>> # + >>> out = apply_fn((), x, pattern=A) + >>> # output is the same as `x @ A` of shape (5, 3, 32) + >>> # + >>> # Sparse NHH binary pattern with 10 edges + >>> n_edges = 10 + >>> A_sparse = random.randint(random.PRNGKey(3), + >>> shape=(x.shape[0], n_edges, 1, 2), + >>> minval=0, + >>> maxval=x.shape[2]) + >>> # + >>> # Setting `implementation="SPARSE"` to invoke the segment sum + >>> # implementation. + >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2, + >>> batch_axis=0, + >>> channel_axis=1, + >>> implementation="SPARSE") + >>> # + >>> out = apply_fn((), x, pattern=A_sparse) + >>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`. + >>> # + >>> # 2D inputs + >>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16)) # NCHW + >>> # + >>> # 2) NHWHW dense binary adjacency matrix + >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16)) + >>> # `A[n, h1, w1, h2, w2] == True` + >>> # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`. + >>> # + >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3), + >>> batch_axis=0, + >>> channel_axis=1) + >>> # + >>> out = apply_fn((), x, pattern=A) + >>> # output is of shape (5, 3, 32, 16), the same as + >>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16)) + >>> # ).reshape(x.shape)` + >>> # + >>> # 3) NWW binary adjacency matrix + >>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16)) + >>> # `A[n, w1, w2] == True` + >>> # means an edge between rows `w1` and `w2` in image `n`. + >>> # + >>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,), + >>> batch_axis=0, + >>> channel_axis=1) + >>> # + >>> out = apply_fn((), x, pattern=A) + >>> # output is of shape (5, 3, 32, 16), the same as + >>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)` + >>> # + >>> # 4) Infinite width example + >>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32)) # NCH + >>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32)) # NCH + >>> # + >>> # NHH binary adjacency matrices + >>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32)) + >>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32)) + >>> # + >>> _, _, kernel_fn_id = stax.Identity() + >>> # + >>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2, + >>> batch_axis=0, + >>> channel_axis=1) + >>> # + >>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1) + >>> # initial NNGP of shape (5, 2, 32, 32) + >>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2)) + >>> # output NNGP of same shape (5, 2, 32, 32): + >>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]` Args: aggregate_axis: @@ -399,7 +429,7 @@ def Aggregate( (`E ~> O(V^1.5)`), while `"SPARSE"` uses `jax.ops.segment_sum` and is recommended for sparse graphs (`E ~< O(V)`). Note that different `implementation`s require different `pattern` array format - see the - layer docstring above for details. + :obj:`Aggregate` layer docstring above for details. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -725,9 +755,9 @@ def Dense( parameterization: str = 'ntk', s: Tuple[int, int] = (1, 1), ) -> InternalLayerMasked: - r"""Layer constructor function for a dense (fully-connected) layer. + r"""Dense (fully-connected, matrix product). - Based on `jax.example_libraries.stax.Dense`. + Based on :obj:`jax.example_libraries.stax.Dense`. Args: out_dim: @@ -751,14 +781,18 @@ def Dense( parameterization: Either `"ntk"` or `"standard"`. - Under `"ntk"` parameterization (https://arxiv.org/abs/1806.07572, page 3), + Under `"ntk"` parameterization (page 3 in "`Neural Tangent Kernel: + Convergence and Generalization in Neural Networks + `_"), weights and biases are initialized as :math:`W_{ij} \sim \mathcal{N}(0,1)`, :math:`b_i \sim \mathcal{N}(0,1)`, and the finite width layer equation is :math:`z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i`, where `N` is `out_dim`. - Under `"standard"` parameterization (https://arxiv.org/abs/2001.07301), + Under `"standard"` parameterization ("`On the infinite width limit of + neural networks with a standard parameterization + `_".), weights and biases are initialized as :math:`W_{ij} \sim \mathcal{N}(0, W_{std}^2/N)`, :math:`b_i \sim \mathcal{N}(0,\sigma_b^2)`, and the finite width layer @@ -766,22 +800,26 @@ def Dense( :math:`z_i = \frac{1}{s} \sum_j W_{ij} x_j + b_i`, where `N` is `out_dim`. `N` corresponds to the respective variable in - https://arxiv.org/abs/2001.07301. + "`On the infinite width limit of neural networks with a standard + parameterization `_". s: only applicable when `parameterization="standard"`. A tuple of integers specifying the width scalings of the input and the output of the layer, i.e. the weight matrix `W` of the layer has shape `(s[0] * in_dim, s[1] * out_dim)`, and the bias has size `s[1] * out_dim`. - Note that we need `s[0]` (scaling of the previous layer) to infer - `in_dim` from `input_shape`. Also note that for the bottom layer, `s[0]` - must be `1`, and for all other layers `s[0]` must be equal to `s[1]` of - the previous layer. For the top layer, `s[1]` is expected to be `1` - (recall that the output size is `s[1] * out_dim`, and in common infinite - network research input and output sizes are considered fixed). + + .. note:: + We need `s[0]` (scaling of the previous layer) to infer `in_dim` from + `input_shape`. Further, for the bottom layer, `s[0]` must be `1`, and + for all other layers `s[0]` must be equal to `s[1]` of the previous + layer. For the top layer, `s[1]` is expected to be `1` (recall that the + output size is `s[1] * out_dim`, and in common infinite network + research input and output sizes are considered fixed). `s` corresponds to the respective variable in - https://arxiv.org/abs/2001.07301. + "`On the infinite width limit of neural networks with a standard + parameterization `_". For `parameterization="ntk"`, or for standard, finite-width networks corresponding to He initialization, `s=(1, 1)`. @@ -893,9 +931,9 @@ def Conv( parameterization: str = 'ntk', s: Tuple[int, int] = (1, 1), ) -> InternalLayerMasked: - """Layer construction function for a general convolution layer. + """General convolution. - Based on `jax.example_libraries.stax.GeneralConv`. + Based on :obj:`jax.example_libraries.stax.GeneralConv`. Args: out_chan: @@ -953,9 +991,9 @@ def ConvTranspose( parameterization: str = 'ntk', s: Tuple[int, int] = (1, 1), ) -> InternalLayerMasked: - """Layer construction function for a general transpose convolution layer. + """General transpose convolution. - Based on `jax.example_libraries.stax.GeneralConvTranspose`. + Based on :obj:`jax.example_libraries.stax.GeneralConvTranspose`. Args: out_chan: @@ -1013,7 +1051,7 @@ def ConvLocal( parameterization: str = 'ntk', s: Tuple[int, int] = (1, 1), ) -> InternalLayerMasked: - """Layer construction function for a general unshared convolution layer. + """General unshared convolution. Also known and "Locally connected networks" or LCNs, these are equivalent to convolutions except for having separate (unshared) kernels at different @@ -1075,9 +1113,9 @@ def _Conv( transpose: bool, shared_weights: bool ) -> InternalLayerMasked: - """Layer construction function for a general convolution layer. + """General convolution. - Based on `jax.example_libraries.stax.GeneralConv`. + Based on :obj:`jax.example_libraries.stax.GeneralConv`. Args: out_chan: @@ -1183,7 +1221,7 @@ def ntk_init_fn(rng, input_shape): lax_conv = functools.partial(lax.conv_general_dilated_local, filter_shape=filter_shape) def ntk_init_fn(rng, input_shape): - """Adapted from `jax.example_libraries.stax.GeneralConv`.""" + """Adapted from :obj:`jax.example_libraries.stax.GeneralConv`.""" filter_shape_iter = iter(filter_shape) conv_kernel_shape = [out_chan if c == 'O' else input_shape[lhs_spec.index('C')] if c == 'I' else @@ -1385,7 +1423,9 @@ def mask_fn(mask, input_shape): rhs_shape.insert(rhs_spec.index(c), 1) # TODO(romann): revisit based on http://b/235531081. - rhs = np.ones(rhs_shape, dtype=None if jax.default_backend() == 'gpu' else mask.dtype) + rhs = np.ones( + rhs_shape, + dtype=None if jax.default_backend() == 'gpu' else mask.dtype) mask = lax.conv_transpose( mask.astype(rhs.dtype), rhs, @@ -1411,9 +1451,9 @@ def AvgPool(window_shape: Sequence[int], normalize_edges: bool = False, batch_axis: int = 0, channel_axis: int = -1) -> InternalLayerMasked: - """Layer construction function for an average pooling layer. + """Average pooling. - Based on `jax.example_libraries.stax.AvgPool`. + Based on :obj:`jax.example_libraries.stax.AvgPool`. Args: window_shape: The number of pixels over which pooling is to be performed. @@ -1445,9 +1485,9 @@ def SumPool(window_shape: Sequence[int], padding: str = Padding.VALID.name, batch_axis: int = 0, channel_axis: int = -1) -> InternalLayerMasked: - """Layer construction function for a 2D sum pooling layer. + """Sum pooling. - Based on `jax.example_libraries.stax.SumPool`. + Based on :obj:`jax.example_libraries.stax.SumPool`. Args: window_shape: The number of pixels over which pooling is to be performed. @@ -1476,10 +1516,10 @@ def _Pool( normalize_edges: bool, batch_axis: int, channel_axis: int) -> InternalLayerMasked: - """Layer construction function for a 2D pooling layer. + """General pooling. - Based on `jax.example_libraries.stax.AvgPool` and - `jax.example_libraries.stax.SumPool`. + Based on :obj:`jax.example_libraries.stax.AvgPool` and + :obj:`jax.example_libraries.stax.SumPool`. Args: pool_type: specifies whether average or sum pooling should be performed. @@ -1594,7 +1634,7 @@ def GlobalSumPool( batch_axis: int = 0, channel_axis: int = -1 ) -> InternalLayerMasked: - """Layer construction function for a global sum pooling layer. + """Global sum pooling. Sums over and removes (`keepdims=False`) all spatial dimensions, preserving the order of batch and channel axes. @@ -1618,7 +1658,7 @@ def GlobalAvgPool( batch_axis: int = 0, channel_axis: int = -1 ) -> InternalLayerMasked: - """Layer construction function for a global average pooling layer. + """Global average pooling. Averages over and removes (`keepdims=False`) all spatial dimensions, preserving the order of batch and channel axes. @@ -1641,7 +1681,7 @@ def _GlobalPool( batch_axis: int, channel_axis: int ) -> InternalLayerMasked: - """Layer construction function for a global pooling layer. + """General global pooling. Pools over and removes (`keepdims=False`) all spatial dimensions, preserving the order of batch and channel axes. @@ -1728,16 +1768,17 @@ def Flatten( batch_axis: int = 0, batch_axis_out: int = 0 ) -> InternalLayerMasked: - """Layer construction function for flattening all non-batch dimensions. + """Flattening all non-batch dimensions. - Based on `jax.example_libraries.stax.Flatten`, but allows to specify batch - axes. + Based on :obj:`jax.example_libraries.stax.Flatten`, but allows to specify + batch axes. Args: - batch_axis: Specifies the input batch dimension. Defaults to `0`, the - leading axis. - batch_axis_out: Specifies the output batch dimension. Defaults to `0`, the - leading axis. + batch_axis: + Specifies the input batch dimension. Defaults to `0`, the leading axis. + + batch_axis_out: + Specifies the output batch dimension. Defaults to `0`, the leading axis. Returns: `(init_fn, apply_fn, kernel_fn)`. @@ -1818,14 +1859,42 @@ def mask_fn(mask, input_shape): class PositionalEmbedding(enum.Enum): - """Type of positional embeddings to use in a `GlobalSelfAttention` layer.""" + """Type of positional embeddings to use in a :obj:`GlobalSelfAttention` layer. + + Attributes: + NONE: + no additional positional embeddings. + + SUM: + positional embeddings are added to activations. + + CONCAT: + positional embeddings are concatenated with activations. + """ NONE = 'NONE' SUM = 'SUM' CONCAT = 'CONCAT' class AttentionMechanism(enum.Enum): - """Type of nonlinearity to use in a `GlobalSelfAttention` layer.""" + """Type of nonlinearity to use in a :obj:`GlobalSelfAttention` layer. + + Attributes: + SOFTMAX: + attention weights are computed by passing the dot product between keys + and queries through :obj:`jax.nn.softmax`. + + IDENTITY: + attention weights are the dot product between keys and queries. + + ABS: + attention weights are computed by passing the dot product between keys + and queries through :obj:`jax.numpy.abs`. + + RELU: + attention weights are computed by passing the dot product between keys + and queries through :obj:`jax.nn.relu`. + """ SOFTMAX = 'SOFTMAX' IDENTITY = 'IDENTITY' ABS = 'ABS' @@ -1862,9 +1931,11 @@ def GlobalSelfAttention( val_pos_emb: bool = False, batch_axis: int = 0, channel_axis: int = -1) -> InternalLayerMasked: - """Layer construction function for (global) scaled dot-product self-attention. + """Global scaled dot-product self-attention. - Infinite width results based on https://arxiv.org/abs/2006.10540. + Infinite width results based on + "`Infinite attention: NNGP and NTK for deep attention networks + `_". Two versions of attention are available (the version to be used is determined by the argument `linear_scaling`): @@ -1884,37 +1955,47 @@ def GlobalSelfAttention( attention weights. The final computation for single head is then - :math:`f_h (x) + attention_mechanism( Q(x) K(x)^T) V(x)` + `f_h (x) + attention_mechanism( Q(x) K(x)^T) V(x)` and the output of this layer is computed as - :math:`f(x) = concat[f_1(x) , ... , f_{} (x)] W_{out} + b` + `f(x) = concat[f_1(x) , ... , f_{} (x)] W_{out} + b` where the shape of `b` is `(n_chan_out,)`, i.e., single bias per channel. The `kernel_fn` computes the limiting kernel of the outputs of this layer as the number of heads and the number of feature dimensions of keys/queries goes to infinity. + For details, please see "`Infinite attention: NNGP and NTK for deep attention + networks `_". + Args: n_chan_out: number of feature dimensions of outputs. + n_chan_key: number of feature dimensions of keys/queries. + n_chan_val: number of feature dimensions of values. + n_heads: number of attention heads. + linear_scaling: if `True`, the dot products between keys and queries are scaled by `1 / n_chan_key` and the key and query weight matrices are tied; if `False`, the dot products are scaled by `1 / sqrt(n_chan_key)` and the key and query matrices are independent. + W_key_std: init standard deviation of the key weights values. Due to NTK parameterization, influences computation only through the product `W_key_std * W_query_std`. + W_value_std: init standard deviation of the value weights values. Due to NTK parameterization, influences computation only through the product `W_out_std * W_value_std`. + W_query_std: init standard deviation of the query weights values; if `linear_scaling` is `True` (and thus key and query weights are tied - see above) then keys @@ -1922,15 +2003,19 @@ def GlobalSelfAttention( computed with `WQ = W_query_std * W / sqrt(n_chan_in)` weight matrices. Due to NTK parameterization, influences computation only through the product `W_key_std * W_query_std`. + W_out_std: initial standard deviation of the output weights values. Due to NTK parameterization, influences computation only through the product `W_out_std * W_value_std`. + b_std: initial standard deviation of the bias values. `None` means no bias. + attention_mechanism: a string, `"SOFTMAX"`, `"IDENTITY"`, `"ABS"`, or `"RELU"`, the transformation applied to dot product attention weights. + pos_emb_type: a string, `"NONE"`, `"SUM"`, or `"CONCAT"`, the type of positional embeddings to use. In the infinite-width limit, `"SUM"` and `"CONCAT"` @@ -1940,11 +2025,13 @@ def GlobalSelfAttention( which leads to different effective variances when using `"SUM"` and `"CONCAT"` embeddings, even if all variance scales like `W_key_std` etc. are the same. + pos_emb_p_norm: use the unnormalized L-`p` distance to the power of `p` (with `p == pos_emb_p_norm`) to compute pairwise distances for positional embeddings (see `pos_emb_decay_fn` for details). Used only if `pos_emb_type != "NONE"` and `pos_emb_decay_fn is not None`. + pos_emb_decay_fn: a function applied to the L-`p` distance to the power of `p` (with `p == pos_emb_p_norm`) distance between two spatial positions to produce @@ -1952,6 +2039,7 @@ def GlobalSelfAttention( exponential decay, etc.). `None` is equivalent to an indicator function `lambda d: d == 0`, and returns a diagonal covariance matrix. Used only if `pos_emb_type != "NONE"`. + n_chan_pos_emb: number of channels in positional embeddings. `None` means use the same number of channels as in the layer inputs. Can be used to tune the @@ -1959,6 +2047,7 @@ def GlobalSelfAttention( if `pos_emb_type == "CONCAT"`. Used only if `pos_emb_type != "NONE"`. Will trigger an error if `pos_emb_type == "SUM"` and `n_chan_pos_emb` is not `None` or does not match the layer inputs channel size at runtime. + W_pos_emb_std: init standard deviation of the random positional embeddings. Can be used to tune the contribution of positional embeddings relative to the @@ -1967,12 +2056,15 @@ def GlobalSelfAttention( `n_chan_pos_emb` when `pos_emb_type == "CONCAT"`, or, if `pos_emb_type == "CONCAT"`, adjust `W_key_std` etc. relative to `W_pos_emb_std`, to keep the total output variance fixed. + val_pos_emb: `True` indicates using positional embeddings when computing all of the keys/queries/values matrices, `False` makes them only used for keys and queries, but not values. Used only if `pos_emb_type != "NONE"`. + batch_axis: Specifies the batch dimension. Defaults to `0`, the leading axis. + channel_axis: Specifies the channel / feature dimension. Defaults to `-1`, the trailing axis. For `kernel_fn`, channel size is considered to be infinite. @@ -2443,14 +2535,15 @@ def prepare_mask(m): @layer @supports_masking(remask_kernel=False) def Dropout(rate: float, mode: str = 'train') -> InternalLayer: - """Dropout layer. + """Dropout. - Based on `jax.example_libraries.stax.Dropout`. + Based on :obj:`jax.example_libraries.stax.Dropout`. Args: rate: Specifies the keep `rate`, e.g. `rate=1` is equivalent to keeping all neurons. + mode: Either `"train"` or `"test"`. @@ -2504,7 +2597,7 @@ def ImageResize( batch_axis: int = 0, channel_axis: int = -1 ) -> InternalLayerMasked: - """Image resize function mimicking `jax.image.resize`. + """Image resize function mimicking :obj:`jax.image.resize`. Docstring adapted from https://jax.readthedocs.io/en/latest/_modules/jax/_src/image/scale.html#resize @@ -2521,7 +2614,7 @@ def ImageResize( are ignored. `ResizeMethod.LINEAR`, `"linear"`, `"bilinear"`, `"trilinear"`, `"triangle"`: - `Linear interpolation`_. If `antialias` is ``True``, uses a triangular + `Linear interpolation`_. If `antialias` is `True`, uses a triangular filter when downsampling. The following methods are NOT SUPPORTED in `kernel_fn` (only `init_fn` and @@ -2536,7 +2629,8 @@ def ImageResize( `ResizeMethod.LANCZOS5`, `"lanczos5"`: `Lanczos resampling`_, using a kernel of radius 5. - .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation + .. _Nearest neighbor interpolation: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling @@ -2548,10 +2642,12 @@ def ImageResize( distinguish spatial dimensions from batch or channel dimensions, so this includes all dimensions of the image. To leave a certain dimension (e.g. batch or channel) unchanged, set the respective entry to `-1`. - Note that setting it to the respective size of the `input` also works, - but will make `kernel_fn` computation much more expensive with no benefit. - Further, note that `kernel_fn` does not support resizing the - `channel_axis`, therefore `shape[channel_axis]` should be set to `-1`. + + .. note:: + Setting a `shape` entry to the respective size of the `input` also + works, but will make `kernel_fn` computation much more expensive with + no benefit. Further, note that `kernel_fn` does not support resizing the + `channel_axis`, therefore `shape[channel_axis]` should be set to `-1`. method: the resizing method to use; either a `ResizeMethod` instance or a @@ -2590,30 +2686,33 @@ def apply_fn(params, x, **kwargs): precision=precision) def mask_fn(mask, input_shape): - # Interpolation (except for "NEAREST") is done in float format: - # https://github.com/google/jax/issues/3811. Float converted back to bool - # rounds up all non-zero elements to `True`, so naively resizing the `mask` - # will mark any output that has at least one contribution from a masked - # input as fully masked. This can lead to mask growing unexpectedly, e.g. - # consider a 5x5 image with a single masked pixel in the center: - # - # >>> mask = np.array([[0, 0, 0, 0, 0], - # >>> [0, 0, 0, 0, 0], - # >>> [0, 0, 1, 0, 0], - # >>> [0, 0, 0, 0, 0], - # >>> [0, 0, 0, 0, 0]], dtype=np.bool_) - # - # Downsampling this mask to 2x2 will mark all output pixels as masked! - # - # >>> jax.image.resize(mask, (2, 2), method='bilinear').astype(np.bool_) - # >>> DeviceArray([[ True, True], - # >>> [ True, True]], dtype=bool) - # - # Therefore, througout `stax` we rather follow the convention of marking - # outputs as masked if they _only_ have contributions from masked elements - # (in other words, we don't let the mask destroy information; let content - # have preference over mask). For this we invert the mask before and after - # resizing, to round up unmasked outputs instead. + """Behavior of interpolation with masking. + + Interpolation (except for "NEAREST") is done in float format: + https://github.com/google/jax/issues/3811. Float converted back to bool + rounds up all non-zero elements to `True`, so naively resizing the `mask` + will mark any output that has at least one contribution from a masked + input as fully masked. This can lead to mask growing unexpectedly, e.g. + consider a 5x5 image with a single masked pixel in the center: + + >>> mask = np.array([[0, 0, 0, 0, 0], + >>> [0, 0, 0, 0, 0], + >>> [0, 0, 1, 0, 0], + >>> [0, 0, 0, 0, 0], + >>> [0, 0, 0, 0, 0]], dtype=np.bool_) + + Downsampling this mask to 2x2 will mark all output pixels as masked! + + >>> jax.image.resize(mask, (2, 2), method='bilinear').astype(np.bool_) + DeviceArray([[ True, True], + [ True, True]], dtype=bool) + + Therefore, througout `stax` we rather follow the convention of marking + outputs as masked if they _only_ have contributions from masked elements + (in other words, we don't let the mask destroy information; let content + have preference over mask). For this we invert the mask before and after + resizing, to round up unmasked outputs instead. + """ return ~jax.image.resize(image=~mask, shape=_shape(mask.shape), method=method, @@ -2732,21 +2831,32 @@ def _same_pad_for_filter_shape( axes: Sequence[int], mode: str = 'wrap', ) -> np.ndarray: - """Pad an array to imitate `SAME` padding with `VALID`. + """Padding imitating :attr:`Padding.SAME` padding with :attr:`Padding.VALID`. See `Returns` section for details. This function is usually needed to - implement `CIRCULAR` padding using `VALID` padding. + implement :attr:`Padding.CIRCULAR` padding using :attr:`Padding.VALID` + padding. Args: - x: `np.ndarray` to pad, e.g. a 4D `NHWC` image. - filter_shape: tuple of positive integers, the convolutional filters spatial - shape (e.g. `(3, 3)` for a 2D convolution). - strides: tuple of positive integers, the convolutional spatial strides, e.g. - e.g. `(1, 1)` for a 2D convolution. - axes: tuple of non-negative integers, the spatial axes to apply - convolution over (e.g. `(1, 2)` for an `NHWC` image). - mode: a string, padding mode, for all options see + x: + `np.ndarray` to pad, e.g. a 4D `NHWC` image. + + filter_shape: + tuple of positive integers, the convolutional filters spatial shape (e.g. + `(3, 3)` for a 2D convolution). + + strides: + tuple of positive integers, the convolutional spatial strides, e.g. + `(1, 1)` for a 2D convolution. + + axes: + tuple of non-negative integers, the spatial axes to apply convolution + over (e.g. `(1, 2)` for an `NHWC` image). + + mode: + a string, padding mode, for all options see https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html. + Returns: A `np.ndarray` of the same dimensionality as `x` padded to a potentially larger shape such that a `"VALID"` convolution with `filter_shape` applied @@ -3082,7 +3192,8 @@ def _conv_kernel_full_spatial_loop( filter_shape: Sequence[int], strides: Sequence[int], padding: Padding, - lax_conv: Callable, + lax_conv: Callable[ + [np.ndarray, np.ndarray, Tuple[int, ...], str], np.ndarray], get_n_channels: Callable[[int], int] ) -> np.ndarray: padding = Padding.VALID if padding == Padding.CIRCULAR else padding @@ -3285,7 +3396,7 @@ def _pool_kernel( def _normalize(lhs, out, normalize_edges, padding, strides, window_shape): if padding == Padding.SAME and normalize_edges: - # `SAME` padding in `jax.example_libraries.stax.AvgPool` normalizes by + # `SAME` padding in :obj:`jax.example_libraries.stax.AvgPool` normalizes by # actual window size, which is smaller at the edges. one = np.ones_like(lhs, lhs.dtype) window_sizes = lax.reduce_window(one, 0., lax.add, window_shape, strides, @@ -3473,9 +3584,10 @@ def _pool_mask( def _pooling_layer(reducer, init_val, rescaler=None): - """Adapted from `jax.example_libraries.stax`.""" + """Adapted from :obj:`jax.example_libraries.stax`.""" + def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None): - """Layer construction function for a pooling layer.""" + """Pooling.""" window_shape = tuple(window_shape) strides = strides or (1,) * len(window_shape) rescale = rescaler(window_shape, strides, padding) if rescaler else None @@ -3496,11 +3608,13 @@ def init_fun(rng, input_shape): out_shape = lax.reduce_window_shape_tuple( input_shape, window_shape, strides, padding_vals, ones, ones) return out_shape, () + def apply_fun(params, inputs, **kwargs): out = lax.reduce_window(inputs, init_val, reducer, window_shape, strides, padding) return rescale(out, inputs, spec) if rescale else out return init_fun, apply_fun + return PoolingLayer diff --git a/neural_tangents/_src/stax/requirements.py b/neural_tangents/_src/stax/requirements.py index f581228f..df22bee3 100644 --- a/neural_tangents/_src/stax/requirements.py +++ b/neural_tangents/_src/stax/requirements.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Requirement management for `stax` layers.""" +"""Requirement management for :obj:`~neural_tangents.stax` layers.""" import enum from typing import Callable, Optional, Tuple, Union, Sequence @@ -35,20 +35,23 @@ def layer(layer_fn: Callable[..., InternalLayer]) -> Callable[..., Layer]: - """A convenience decorator to be added to all public layers like `Relu` etc. + """A convenience decorator to be added to all public layers. - Makes the `kernel_fn` of the layer work with both input `np.ndarray` - (when the layer is the first one applied to inputs), and with `Kernel` for - intermediary layers. Also adds optional arguments to the `kernel_fn` to - allow specifying the computation and returned results with more flexibility. + Used in :obj:`~neural_tangents.stax.Relu` etc. + + Makes the `kernel_fn` of the layer work with both input + :class:`jax.numpy.ndarray` (when the layer is the first one applied to + inputs), and with :class:`~neural_tangents.Kernel` for intermediary layers. + Also adds optional arguments to the `kernel_fn` to allow specifying the + computation and returned results with more flexibility. Args: layer_fn: Layer function returning triple `(init_fn, apply_fn, kernel_fn)`. Returns: A function with the same signature as `layer` with `kernel_fn` now - accepting `np.ndarray` as inputs if needed, and accepts optional `get`, - `diagonal_batch`, `diagonal_spatial` arguments. + accepting :class:`jax.numpy.ndarray` as inputs if needed, and accepts + optional `get`, `diagonal_batch`, `diagonal_spatial` arguments. """ name = layer_fn.__name__ @@ -68,7 +71,7 @@ def requires(**static_reqs): Use this to specify your `kernel_fn` input kernel requirements. See Also: - `stax.Diagonal`, `stax.Input`, `stax.Output`. + :class:`Diagonal`, :class:`Input`, :class:`Output`. """ @@ -145,8 +148,9 @@ def supports_masking(remask_kernel: bool): Must be applied before the `layer` decorator. Args: - remask_kernel: `True` to zero-out kernel covariance entries between masked - inputs after applying `kernel_fn`. Some layers don't need this and setting + remask_kernel: + `True` to zero-out kernel covariance entries between masked inputs after + applying `kernel_fn`. Some layers don't need this and setting `remask_kernel=False` can save compute. Returns: @@ -255,7 +259,19 @@ def _has_req(f: Callable) -> bool: class Bool(enum.IntEnum): - """Helper trinary logic class.""" + """Helper trinary logic class. See :class:`Diagonal` for details. + + Attributes: + NO: + `False`. + + MAYBE: + Maybe. + + YES: + `True`. + + """ NO = 0 MAYBE = 1 YES = 2 @@ -273,7 +289,8 @@ class Diagonal: The intended behavior is to be diagonal-only iff a) output off-diagonal entries are all zeros, and - b) diagonal-only `Kernel` is sufficient for all steps of computation. + b) diagonal-only :class:`~neural_tangents.Kernel` is sufficient for all + steps of computation. Note that currently this parameter is shared between all parallel branches, even if this is excessive, and it is defined once for the whole network and @@ -289,16 +306,16 @@ class Diagonal: Attributes: input: specifies whether inputs to given layer can contain only diagonal - entries. `_Bool.YES` means "yes"; `_Bool.MAYBE` means iff off-diagonal - entries are zero. `_Bool.NO` means "no". When traversing the network - tree from inputs to outputs (as well as parallel branches from left/right - to right/left) can only decrease. + entries. :attr:`Bool.YES` means "yes"; :attr:`Bool.MAYBE` means iff + off-diagonal entries are zero. :attr:`Bool.NO` means "no". When + traversing the network tree from inputs to outputs (as well as parallel + branches from left/right to right/left) can only decrease. output: specifies whether any outputs (starting from this layer to the output of - the network) can contain only diagonal entries. `_Bool.YES` means yes; - `_Bool.MAYBE` means "yes" after current layer, but may become "no" - further in the network. `_Bool.NO` means "no". + the network) can contain only diagonal entries. :attr:`Bool.YES` means + yes; :attr:`Bool.MAYBE` means "yes" after current layer, but may become + "no" further in the network. :attr:`Bool.NO` means "no". """ input: Bool = Bool.YES @@ -423,14 +440,19 @@ def _cov( """Computes uncentered covariance (nngp) between two batches of inputs. Args: - x1: a (2+S)D (S >= 0) `np.ndarray` of shape + x1: + a (2+S)D (S >= 0) `np.ndarray` of shape `(batch_size_1, , n_channels)`. `batch_size_1`, `n_channels` may be in different positions based on `batch_axis` and `channel_axis`. - x2: an optional `np.ndarray` that has the same shape as `a` apart from + + x2: + an optional `np.ndarray` that has the same shape as `a` apart from possibly different batch (`batch_size_2`) dimension. `None` means `x2 == x1`. - diagonal_spatial: Specifies whether only the diagonals of the + + diagonal_spatial: + Specifies whether only the diagonals of the location-location covariances will be computed, (`diagonal_spatial == True`, `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`), @@ -438,9 +460,14 @@ def _cov( (`diagonal_spatial == False`, `nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)`). - batch_axis: Specifies which axis is the batch axis. - channel_axis: Specifies which axis is the channel / feature axis. - For `kernel_fn`, channel size is considered to be infinite. + + batch_axis: + Specifies which axis is the batch axis. + + channel_axis: + Specifies which axis is the channel / feature axis. For `kernel_fn`, + channel size is considered to be infinite. + Returns: Matrix of uncentred batch covariances with shape `(batch_size_1, batch_size_2, )` @@ -490,42 +517,45 @@ def _inputs_to_kernel( Example: >>> x = np.ones((10, 32, 16, 3)) >>> o = _inputs_to_kernel(x, None, - >>> diagonal_batch=True, - >>> diagonal_spatial=False, - >>> compute_ntk=True, - >>> batch_axis=0, - >>> channel_axis=-1) + >>> diagonal_batch=True, + >>> diagonal_spatial=False, + >>> compute_ntk=True, + >>> batch_axis=0, + >>> channel_axis=-1) >>> o.cov1.shape, o.ntk.shape (10, 32, 32, 16, 16), (10, 10, 32, 32, 16, 16) >>> o = _inputs_to_kernel(x, None, - >>> diagonal_batch=True, - >>> diagonal_spatial=True, - >>> compute_ntk=True, - >>> batch_axis=0, - >>> channel_axis=-1) + >>> diagonal_batch=True, + >>> diagonal_spatial=True, + >>> compute_ntk=True, + >>> batch_axis=0, + >>> channel_axis=-1) >>> o.cov1.shape, o.ntk.shape (10, 32, 16), (10, 10, 32, 16) >>> x1 = np.ones((10, 128)) >>> x2 = np.ones((20, 128)) >>> o = _inputs_to_kernel(x1, x2, - >>> diagonal_batch=True, - >>> diagonal_spatial=True, - >>> compute_ntk=False, - >>> batch_axis=0, - >>> channel_axis=-1) + >>> diagonal_batch=True, + >>> diagonal_spatial=True, + >>> compute_ntk=False, + >>> batch_axis=0, + >>> channel_axis=-1) >>> o.cov1.shape, o.nngp.shape (10,), (10, 20) Args: - x1: an `(S+2)`-dimensional `np.ndarray` of shape + x1: + an `(S+2)`-dimensional `np.ndarray` of shape `(batch_size_1, height, width, depth, ..., n_channels)` with `S` spatial dimensions (`S >= 0`). Dimensions may be in different order based on `batch_axis` and `channel_axis`. - x2: an optional `np.ndarray` with the same shape as `x1` apart - from possibly different batch size. `None` means `x2 == x1`. + x2: + an optional `np.ndarray` with the same shape as `x1` apart from possibly + different batch size. `None` means `x2 == x1`. - diagonal_batch: Specifies whether `cov1` and `cov2` store only + diagonal_batch: + Specifies whether `cov1` and `cov2` store only the diagonal of the sample-sample covariance (`diagonal_batch == True`, `cov1.shape == (batch_size_1, ...)`), @@ -533,9 +563,9 @@ def _inputs_to_kernel( (`diagonal_batch == False`, `cov1.shape == (batch_size_1, batch_size_1, ...)`). - diagonal_spatial: Specifies whether all (`cov1`, `ntk`, etc.) - input covariance matrcies should store only the diagonals of the - location-location covariances + diagonal_spatial: + Specifies whether all (`cov1`, `ntk`, etc.) input covariance matrcies + should store only the diagonals of the location-location covariances (`diagonal_spatial == True`, `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`), or the full covariance @@ -543,27 +573,31 @@ def _inputs_to_kernel( `nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)`). - compute_ntk: `True` to compute both NTK and NNGP kernels, - `False` to only compute NNGP. + compute_ntk: + `True` to compute both NTK and NNGP kernels, `False` to only compute NNGP. - batch_axis: Specifies which axis is the batch axis. + batch_axis: + Specifies which axis is the batch axis. - channel_axis: Specifies which axis is the channel / feature axis. - For `kernel_fn`, channel size is considered to be infinite. + channel_axis: + Specifies which axis is the channel / feature axis. For `kernel_fn`, + channel size is considered to be infinite. - mask_constant: an optional `float`, the value in inputs to be considered as - masked (e.g. padding in a batch of sentences). `None` means no masking. - Can also be `np.nan`, `np.inf` etc. Beware of floating point precision - errors and try to use an atypical for inputs value. + mask_constant: + an optional `float`, the value in inputs to be considered as masked (e.g. + padding in a batch of sentences). `None` means no masking. Can also be + `np.nan`, `np.inf` etc. Beware of floating point precision errors and try + to use an atypical for inputs value. - eps: a small number used to check whether x1 and x2 are the same up to - `eps`. + eps: + a small number used to check whether x1 and x2 are the same up to `eps`. - **kwargs: other arguments passed to all intermediary `kernel_fn` calls (not - used here). + **kwargs: + other arguments passed to all intermediary `kernel_fn` calls (not used + here). Returns: - The `Kernel` object containing inputs covariance[s]. + The :class:`~neural_tangents.Kernel` object containing inputs covariance[s]. """ if not (isinstance(x1, (onp.ndarray, np.ndarray)) and diff --git a/neural_tangents/_src/utils/dataclasses.py b/neural_tangents/_src/utils/dataclasses.py index f1bb1d16..5590daae 100644 --- a/neural_tangents/_src/utils/dataclasses.py +++ b/neural_tangents/_src/utils/dataclasses.py @@ -32,23 +32,24 @@ def dataclass(clz): Jax transformations such as `jax.jit` and `jax.grad` require objects that are immutable and can be mapped over using the `jax.tree_util` functions. The `dataclass` decorator makes it easy to define custom classes that can be - passed safely to Jax. For example: - - >>> from jax import jit, numpy as np - >>> from neural_tangents._src.utils import dataclasses - >>> - >>> @dataclasses.dataclass - >>> class Data: - >>> array: np.ndarray - >>> a_boolean: bool = dataclasses.field(pytree_node=False) - >>> - >>> data = Data(np.array([1.0]), True) - >>> - >>> data.array = np.array([2.0]) # Data is immutable. Will raise an error. - >>> data = data.replace(array=np.array([2.0])) # Use the replace method. - >>> - >>> # This class can now be used safely in Jax. - >>> jit(lambda data: data.array if data.a_boolean else 0)(data) + passed safely to Jax. + + Example: + >>> from jax import jit, numpy as np + >>> from neural_tangents._src.utils import dataclasses + >>> # + >>> @dataclasses.dataclass + >>> class Data: + >>> array: np.ndarray + >>> a_boolean: bool = dataclasses.field(pytree_node=False) + >>> # + >>> data = Data(np.array([1.0]), True) + >>> # + >>> data.array = np.array([2.0]) # Data is immutable. Will raise an error. + >>> data = data.replace(array=np.array([2.0])) # Use the replace method. + >>> # + >>> # This class can now be used safely in Jax. + >>> jit(lambda data: data.array if data.a_boolean else 0)(data) Args: clz: the class that will be transformed by the decorator. diff --git a/neural_tangents/_src/utils/kernel.py b/neural_tangents/_src/utils/kernel.py index bca922cb..bea526b1 100644 --- a/neural_tangents/_src/utils/kernel.py +++ b/neural_tangents/_src/utils/kernel.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The `Kernel` class with infinite-width NTK and NNGP `np.ndarray` fields.""" - +"""Class with infinite-width NTK and NNGP :class:`jax.numpy.ndarray` fields.""" import operator as op from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple @@ -86,7 +85,8 @@ class Kernel: (`diagonal_spatial == True`, `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`), or the full covariance (`diagonal_spatial == False`, `nngp.shape == - (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)`). + (batch_size_1, batch_size_2, height, height, width, width, depth, depth, + ...)`). Defaults to `False`, but is set to `True` if the output top-layer covariance depends only on the diagonals (e.g. when a CNN network has no pooling layers and `Flatten` on top). diff --git a/neural_tangents/_src/utils/rules.py b/neural_tangents/_src/utils/rules.py index ba0010cb..d6c3eb20 100644 --- a/neural_tangents/_src/utils/rules.py +++ b/neural_tangents/_src/utils/rules.py @@ -381,11 +381,12 @@ def _conv_general_dilated_s( if (rhs.shape[rhs_spec[0]] == feature_group_count and rhs.shape[rhs_spec[1]] == 1): + assert lhs.shape[lhs_spec[1]] == feature_group_count return Structure( in_trace=(), in_trace_idxs=(), out_trace=(), - in_diagonal=((None, rhs_spec[0]),), + in_diagonal=((lhs_spec[1], rhs_spec[0]),), out_diagonal=(out_spec[1],) ) @@ -504,16 +505,19 @@ def _conv_general_dilated_e( # `conv_general_dilated` has `lhs_shape` and `rhs_shape` arguments that are # for some reason not inferred from the `lhs` and `rhs` themselves. # TODO(romann): ask JAX why these are there. - if idx == 0: - params['lhs_shape'] = trimmed_invals[0].shape - params['rhs_shape'] = trimmed_invals[1].shape + dn = params['dimension_numbers'] - elif idx == 1: - params['lhs_shape'] = trimmed_invals[0].shape - params['rhs_shape'] = trimmed_invals[1].shape + if (params['feature_group_count'] == params['lhs_shape'][dn[0][1]] and + params['feature_group_count'] == params['rhs_shape'][dn[1][0]]): + params['feature_group_count'] = 1 - else: - raise ValueError(f'Convolution only has two inputs, got input index {idx}.') + if (params['batch_group_count'] == params['rhs_shape'][dn[1][0]] and + params['batch_group_count'] == params['lhs_shape'][dn[0][0]]): + params['batch_group_count'] = 1 + + lhs, rhs = trimmed_invals + params['lhs_shape'] = lhs.shape + params['rhs_shape'] = rhs.shape return params diff --git a/neural_tangents/_src/utils/typing.py b/neural_tangents/_src/utils/typing.py index 0cd51743..1b88c1a2 100644 --- a/neural_tangents/_src/utils/typing.py +++ b/neural_tangents/_src/utils/typing.py @@ -48,18 +48,18 @@ network computations (for example, when neural networks have nested parallel layers). - Mimicking JAX, we use a lightweight tree structure called an `NTTree`. - `NTTree` has internal nodes that are either lists or tuples and leaves which - are either `np.ndarray` or `Kernel` objects. + Mimicking JAX, we use a lightweight tree structure called an :class:`NTTree`. + :class:`NTTree` has internal nodes that are either lists or tuples and leaves + which are either :class:`jax.numpy.ndarray` or :class:`~neural_tangents.Kernel` objects. """ NTTrees = Union[List[T], Tuple[T, ...]] - """A list or tuple of `NTTree` s. + """A list or tuple of :class:`NTTree` s. """ Shapes = NTTree[Tuple[int, ...]] -"""A shape - a tuple of integers, or an `NTTree` of such tuples. +"""A shape - a tuple of integers, or an :class:`NTTree` of such tuples. """ @@ -139,10 +139,11 @@ def __call__( class AnalyticKernelFn(Protocol): """A type alias for analytic kernel functions. - A kernel function that computes an analytic kernel. Takes either a `Kernel` - or `np.ndarray` inputs and a `get` argument that specifies what quantities - should be computed by the kernel. Returns either a `Kernel` object or - `np.ndarray`-s for kernels specified by `get`. + A kernel function that computes an analytic kernel. Takes either a + :class:`~neural_tangents.Kernel` or :class:`jax.numpy.ndarray` inputs and a + `get` argument that specifies what quantities should be computed by the + kernel. Returns either a :class:`~neural_tangents.Kernel` object or + :class:`jax.numpy.ndarray`-s for kernels specified by `get`. """ def __call__( diff --git a/neural_tangents/_src/utils/utils.py b/neural_tangents/_src/utils/utils.py index a6428228..9a881db3 100644 --- a/neural_tangents/_src/utils/utils.py +++ b/neural_tangents/_src/utils/utils.py @@ -20,7 +20,6 @@ import operator import types from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Sized, Tuple, Type, TypeVar, Union -from .typing import Axes, PyTree import warnings from . import dataclasses @@ -33,6 +32,12 @@ import numpy as onp +PyTree = Any + + +Axes = Union[int, Sequence[int]] + + def is_list_or_tuple(x) -> bool: # We do not want to return True if x is a subclass of list or tuple since # otherwise this will return true for namedtuples. @@ -47,9 +52,10 @@ def is_nt_tree_of(x, dtype: Union[Type, Tuple[Type, ...]]) -> bool: return all(is_nt_tree_of(_x, dtype) for _x in x) -def nt_tree_fn(nargs: Optional[int] = None, - tree_structure_argnum: Optional[int] = None, - reduce: Callable = lambda x: x): +def nt_tree_fn( + nargs: Optional[int] = None, + tree_structure_argnum: Optional[int] = None, + reduce: Callable = lambda x: x): """Convert a function that acts on single inputs to one that acts on trees. `nt_tree_fn` treats the first `nargs` arguments as NTTrees and the remaining @@ -62,14 +68,19 @@ def nt_tree_fn(nargs: Optional[int] = None, is used to infer the structure. Args: - nargs: The number of arguments to be treated as NTTrees. If `nargs` is None + nargs: + The number of arguments to be treated as NTTrees. If `nargs` is `None` then all of the arguments are used. `nargs` can also be negative which follows numpy's semantics for array indexing. - tree_structure_argnum: The argument used to infer the tree structure to be - traversed. If `tree_structure_argnum` is None then a check is performed to - ensure that all trees have the same structure. - reduce: A callable that is applied recursively by each internal tree node - to its children. + + tree_structure_argnum: + The argument used to infer the tree structure to be traversed. If + `tree_structure_argnum` is None then a check is performed to ensure that + all trees have the same structure. + + reduce: + A callable that is applied recursively by each internal tree node to its + children. Returns: A decorator `tree_fn` that transforms a function, `fn`, from acting on @@ -625,9 +636,10 @@ def axis_after_dot(axis: int, ) -def make_2d(x: Optional[np.ndarray], - start_axis: int = 0, - end_axis: Optional[int] = None) -> Optional[np.ndarray]: +def make_2d( + x: Optional[np.ndarray], + start_axis: int = 0, + end_axis: Optional[int] = None) -> Optional[np.ndarray]: """Makes `x` 2D from `start_axis` to `end_axis`, preserving other axes. `x` is assumed to follow the (`X, X, Y, Y, Z, Z`) axes layout. @@ -635,9 +647,9 @@ def make_2d(x: Optional[np.ndarray], Example: >>> x = np.ones((1, 2, 3, 3, 4, 4)) >>> make_2d(x).shape == (12, 24) - >>> + >>> # >>> make_2d(x, 2).shape == (1, 2, 12, 12) - >>> + >>> # >>> make_2d(x, 2, 4).shape == (1, 2, 3, 3, 4, 4) """ if x is None: diff --git a/neural_tangents/experimental/empirical_tf/empirical.py b/neural_tangents/experimental/empirical_tf/empirical.py index a18568ec..53b0479f 100644 --- a/neural_tangents/experimental/empirical_tf/empirical.py +++ b/neural_tangents/experimental/empirical_tf/empirical.py @@ -14,79 +14,87 @@ """Experimental prototype of empirical NTK computation in Tensorflow. -This module is applicable to `tf.Module`, `tf.keras.Model`, or `tf.function` -functions, subject to some conditions (see docstring of `empirical_ntk_fn_tf`). +This module is applicable to :class:`tf.Module`, :class:`tf.keras.Model`, or +:obj:`tf.function` functions, subject to some conditions (see docstring of +:obj:`empirical_ntk_fn_tf`). -The kernel function follows the API of `neural_tangents.empirical_ntk_fn`. +The kernel function follows the API of :obj:`neural_tangents.empirical_ntk_fn`. Please read the respective docstring for more details. +.. warning:: + This module currently appears to have long compile times (but OK runtime), + is prone to triggering XLA errors, and does not distinguish between trainable + and non-trainable parameters of the model. + +For details about the empirical (finite width) NTK computation, please see +"`Fast Finite Width Neural Tangent Kernel `_". + Example: - >>> import tensorflow as tf - >>> from tensorflow.keras import layers - >>> import neural_tangents as nt - >>> - >>> x_train = tf.random.normal((20, 32, 32, 3)) - >>> x_test = tf.random.normal((5, 32, 32, 3)) - >>> - >>> # A CNN. - >>> f = tf.keras.Sequential() - >>> f.add(layers.Conv2D(32, (3, 3), activation='relu', - >>> input_shape=x_train.shape[1:])) - >>> f.add(layers.Conv2D(32, (3, 3), activation='relu')) - >>> f.add(layers.Conv2D(32, (3, 3))) - >>> f.add(layers.Flatten()) - >>> f.add(layers.Dense(10)) - >>> - >>> f.build((None, *x_train.shape[1:])) - >>> _, params = nt.experimental.get_apply_fn_and_params(f) - >>> - >>> # Default setting: reducing over logits (default `trace_axes=(-1,)`; - >>> # pass `vmap_axes=0` because the network is iid along the batch axis, no - >>> # BatchNorm. - >>> kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0) - >>> - >>> # (5, 20) tf.Tensor test-train NTK - >>> nngp_test_train = kernel_fn(x_test, x_train, params) - >>> ntk_test_train = kernel_fn(x_test, x_train, params) - >>> - >>> # Full kernel: not reducing over logits. - >>> kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, trace_axes=(), - >>> vmap_axes=0) - >>> - >>> # (5, 20, 10, 10) tf.Tensor test-train NTK. - >>> k_test_train = kernel_fn(x_test, x_train, params) - >>> - >>> # An FCN - >>> f = tf.keras.Sequential() - >>> f.add(layers.Flatten()) - >>> f.add(layers.Dense(1024, activation='relu')) - >>> f.add(layers.Dense(1024, activation='relu')) - >>> f.add(layers.Dense(10)) - >>> - >>> f.build((None, *x_train.shape[1:])) - >>> _, params = nt.experimental.get_apply_fn_and_params(f) - >>> - >>> # Use ntk-vector products since the network has many parameters - >>> # relative to the cost of forward pass. - >>> ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0, - >>> implementation=2) - >>> - >>> # (5, 5) tf.Tensor test-test NTK - >>> ntk_test_test = ntk_fn(x_test, None, params) - >>> - >>> # Compute only NTK diagonal variances: - >>> ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, diagonal_axes=(0,)) - >>> - >>> # (20,) tf.Tensor train-train NTK diagonal - >>> ntk_train_train_diag = ntk_fn(x_train, None, params) + >>> import tensorflow as tf + >>> from tensorflow.keras import layers + >>> import neural_tangents as nt + >>> # + >>> x_train = tf.random.normal((20, 32, 32, 3)) + >>> x_test = tf.random.normal((5, 32, 32, 3)) + >>> # + >>> # A CNN. + >>> f = tf.keras.Sequential() + >>> f.add(layers.Conv2D(32, (3, 3), activation='relu', + >>> input_shape=x_train.shape[1:])) + >>> f.add(layers.Conv2D(32, (3, 3), activation='relu')) + >>> f.add(layers.Conv2D(32, (3, 3))) + >>> f.add(layers.Flatten()) + >>> f.add(layers.Dense(10)) + >>> # + >>> f.build((None, *x_train.shape[1:])) + >>> _, params = nt.experimental.get_apply_fn_and_params(f) + >>> # + >>> # Default setting: reducing over logits (default `trace_axes=(-1,)`; + >>> # pass `vmap_axes=0` because the network is iid along the batch axis, no + >>> # BatchNorm. + >>> kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0) + >>> # + >>> # (5, 20) tf.Tensor test-train NTK + >>> nngp_test_train = kernel_fn(x_test, x_train, params) + >>> ntk_test_train = kernel_fn(x_test, x_train, params) + >>> # + >>> # Full kernel: not reducing over logits. + >>> kernel_fn = nt.experimental.empirical_ntk_fn_tf(f, trace_axes=(), + >>> vmap_axes=0) + >>> # + >>> # (5, 20, 10, 10) tf.Tensor test-train NTK. + >>> k_test_train = kernel_fn(x_test, x_train, params) + >>> # + >>> # An FCN + >>> f = tf.keras.Sequential() + >>> f.add(layers.Flatten()) + >>> f.add(layers.Dense(1024, activation='relu')) + >>> f.add(layers.Dense(1024, activation='relu')) + >>> f.add(layers.Dense(10)) + >>> # + >>> f.build((None, *x_train.shape[1:])) + >>> _, params = nt.experimental.get_apply_fn_and_params(f) + >>> # + >>> # Use ntk-vector products since the network has many parameters + >>> # relative to the cost of forward pass. + >>> ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0, + >>> implementation=2) + >>> # + >>> # (5, 5) tf.Tensor test-test NTK + >>> ntk_test_test = ntk_fn(x_test, None, params) + >>> # + >>> # Compute only NTK diagonal variances: + >>> ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, diagonal_axes=(0,)) + >>> # + >>> # (20,) tf.Tensor train-train NTK diagonal + >>> ntk_train_train_diag = ntk_fn(x_train, None, params) """ from typing import Callable, Optional, Union import warnings from jax.experimental import jax2tf -import neural_tangents as nt -from neural_tangents._src.empirical import DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_J_RULES, _DEFAULT_NTK_S_RULES +from neural_tangents._src.empirical import NtkImplementation, empirical_ntk_fn, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_J_RULES, _DEFAULT_NTK_S_RULES from neural_tangents._src.utils.typing import Axes, NTTree, PyTree, VMapAxes import tensorflow as tf import tf2jax @@ -97,37 +105,43 @@ def empirical_ntk_fn_tf( trace_axes: Axes = (-1,), diagonal_axes: Axes = (), vmap_axes: VMapAxes = None, - implementation: Union[ - nt.NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION, + implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION, _j_rules: bool = _DEFAULT_NTK_J_RULES, _s_rules: bool = _DEFAULT_NTK_S_RULES, _fwd: Optional[bool] = _DEFAULT_NTK_FWD, ) -> Callable[..., NTTree[tf.Tensor]]: r"""Returns a function to draw a single sample the NTK of a given network `f`. - This function follows the API of `neural_tangents.empirical_ntk_fn`, but is - applicable to Tensorflow `tf.function` or `tf.keras.Model`, via a TF->JAX->TF - roundtrip using `tf2jax` and `jax2tf`. Docstring below adapted from - `neural_tangents.empirical_ntk_fn`. + This function follows the API of :obj:`neural_tangents.empirical_ntk_fn`, but + is applicable to Tensorflow :class:`tf.Module`, :class:`tf.keras.Model`, or + :obj:`tf.function`, via a TF->JAX->TF roundtrip using `tf2jax` and `jax2tf`. + Docstring below adapted from :obj:`neural_tangents.empirical_ntk_fn`. - WARNING: this function is highly experimental and risks returning wrong - results or performing slowly. It is intended to demonstrate the usage of - `neural_tangents.empirical_ntk_fn` in Tensorflow, but has not been - extensively tested. + .. warning:: + This function is experimental and risks returning wrong results or + performing slowly. It is intended to demonstrate the usage of + :obj:`neural_tangents.empirical_ntk_fn` in Tensorflow, but has not been + extensively tested. Specifically, it appears to have very long + compile times (but OK runtime), is prone to triggering XLA errors, and does + not distinguish between trainable and non-trainable parameters of the model. - TODO(romann): support proper division between trainable and non-trainable - variables. + TODO(romann): support division between trainable and non-trainable variables. TODO(romann): investigate slow compile times. Args: f: - `tf.Module` or `tf.function` whose NTK we are computing. Must + :class:`tf.Module` or :obj:`tf.function` whose NTK we are computing. Must satisfy the following: - - if a `tf.function`, must have the signature of `f(params, x)`. - - if a `tf.Module`, must be either a `tf.keras.Model`, or be callable. - - input signature (`f.input_shape` for `tf.Module` or `tf.keras.Model`, - or `f.input_signature` for `tf.function`) must be known. + + - if a :obj:`tf.function`, must have the signature of `f(params, x)`. + + - if a :class:`tf.Module`, must be either a :class:`tf.keras.Model`, or + be callable. + + - input signature (`f.input_shape` for :class:`tf.Module` or + :class:`tf.keras.Model`, or `f.input_signature` for `tf.function`) + must be known. trace_axes: output axes to trace the output kernel over, i.e. compute only the trace @@ -180,31 +194,34 @@ def empirical_ntk_fn_tf( set to `None`, to avoid wrong (and potentially silent) results. implementation: - Applicable only to NTK, an `NtkImplementation` value (or an integer `0`, - `1`, `2`, or `3`). See the `NtkImplementation` enum docstring for details. + An :class:`NtkImplementation` value (or an :class:`int` `0`, `1`, `2`, or + `3`). See the :class:`NtkImplementation` docstring for details. _j_rules: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow custom Jacobian rules for intermediary primitive `dy/dw` - computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to - `False` to use JVPs or VJPs, via JAX's `jacfwd` or `jacrev`. Custom - Jacobian rules (`True`) are expected to be not worse, and sometimes better - than automated alternatives, but in case of a suboptimal implementation - setting it to `False` could improve performance. + Internal debugging parameter, applicable only when + `implementation` is :attr:`NtkImplementation.STRUCTURED_DERIVATIVES` + (`3`) or :attr:`NtkImplementation.AUTO` (`0`). Set to `True` to allow + custom Jacobian rules for intermediary primitive `dy/dw` computations for + MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use + JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom + Jacobian rules (`True`) are expected to be not worse, and sometimes + better than automated alternatives, but in case of a suboptimal + implementation setting it to `False` could improve performance. _s_rules: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow efficient MJJMp rules for structured `dy/dw` primitive - Jacobians. In practice should be set to `True`, and setting it to `False` - can lead to dramatic deterioration of performance. + Internal debugging parameter, applicable only when `implementation` is + :attr:`NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) or + :attr:`NtkImplementation.AUTO` (`0`). Set to `True` to allow efficient + MJJMp rules for structured `dy/dw` primitive Jacobians. In practice + should be set to `True`, and setting it to `False` can lead to dramatic + deterioration of performance. _fwd: - Internal debugging parameter, applicable only to NTK when - `implementation` is `STRUCTURED_DERIVATIVES` (`3`) or `AUTO` (`0`). Set to - `True` to allow `jvp` in intermediary primitive Jacobian `dy/dw` - computations, `False` to always use `vjp`. `None` to decide automatically + Internal debugging parameter, applicable only when `implementation` is + :attr:`NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) or + :attr:`NtkImplementation.AUTO` (`0`). Set to `True` to allow + :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations, + `False` to always use :obj:`jax.vjp`. `None` to decide automatically based on input/output sizes. Applicable when `_j_rules=False`, or when a primitive does not have a Jacobian rule. Should be set to `None` for best performance. @@ -234,22 +251,27 @@ def empirical_ntk_fn_tf( f'please file a bug at ' f'https://github.com/google/neural-tangents.') - ntk_fn = nt.empirical_ntk_fn(apply_fn, **kwargs) + ntk_fn = empirical_ntk_fn(apply_fn, **kwargs) ntk_fn = jax2tf.convert(ntk_fn) ntk_fn = tf.function(ntk_fn, jit_compile=True, autograph=False) return ntk_fn def get_apply_fn_and_params(f: tf.Module): - """Converts a `tf.Module` into a forward-pass `apply_fn` and `params`. + """Converts a :class:`tf.Module` into a forward-pass `apply_fn` and `params`. Use this function to extract `params` to pass to the Tensorflow empirical NTK kernel function. + .. warning:: + This function does not distinguish between trainable and non-trainable + parameters of the model. + Args: - f: a `tf.Module` to convert to a `apply_fn(params, x)` function. Must have - an `input_shape` attribute set (specifying shape of `x`), and be callable - or be a `tf.keras.Model`. + f: + a :class:`tf.Module` to convert to a `apply_fn(params, x)` function. Must + have an `input_shape` attribute set (specifying shape of `x`), and be + callable or be a :class:`tf.keras.Model`. Returns: A tuple fo `(apply_fn, params)`, where `params` is an `NTTree[tf.Tensor]`. diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 57248ef3..5d6ea18d 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -14,53 +14,62 @@ """Closed-form NNGP and NTK library. -This library contains layer constructors mimicking those in -`jax.example_libraries.stax` with similar API apart apart from: +This library contains layers mimicking those in +:obj:`jax.example_libraries.stax` with similar API apart from: -1) Instead of `(init_fn, apply_fn)` tuple, layer constructors return a triple +1) Instead of `(init_fn, apply_fn)` tuple, layers return a triple `(init_fn, apply_fn, kernel_fn)`, where the added `kernel_fn` maps a -`Kernel` to a new `Kernel`, and represents the change in the -analytic NTK and NNGP kernels (`Kernel.nngp`, `Kernel.ntk`). These functions -are chained / stacked together within the `serial` or `parallel` -combinators, similarly to `init_fn` and `apply_fn`. +:class:`~neural_tangents..Kernel` to a new :class:`~neural_tangents.Kernel`, and +represents the change in the analytic NTK and NNGP kernels +(:attr:`~neural_tangents.Kernel.nngp`, :attr:`~neural_tangents.Kernel.ntk`). +These functions are chained / stacked together within the :obj:`serial` or +:obj:`parallel` combinators, similarly to `init_fn` and `apply_fn`. +For details, please see "`Neural Tangents: Fast and Easy Infinite Neural +Networks in Python `_". 2) In layers with random weights, NTK parameterization is used by default -(https://arxiv.org/abs/1806.07572, page 3). Standard parameterization -(https://arxiv.org/abs/2001.07301) can be specified for `Conv` and `Dense` -layers by a keyword argument `parameterization`. - -3) Some functionality may be missing (e.g. `BatchNorm`), and some may be -present only in our library (e.g. `CIRCULAR` padding, `LayerNorm`, -`GlobalAvgPool`, `GlobalSelfAttention`, flexible batch and channel axes etc.). +(see page 3 in +"`Neural Tangent Kernel: Convergence and Generalization in Neural Networks +`_"). Standard parameterization can be +specified for :obj:`Conv` and :obj:`Dense` layers by a keyword argument +`parameterization`. For details, please see "`On the infinite width limit of +neural networks with a standard parameterization +`_". + +3) Some functionality may be missing (e.g. +:obj:`jax.example_libraries.stax.BatchNorm`), and some may be +present only in our library (e.g. :attr:`Padding.CIRCULAR` padding, +:obj:`LayerNorm`, :obj:`GlobalAvgPool`, :obj:`GlobalSelfAttention`, flexible +batch and channel axes etc.). Example: - >>> from jax import random - >>> import neural_tangents as nt - >>> from neural_tangents import stax - >>> - >>> key1, key2 = random.split(random.PRNGKey(1), 2) - >>> x_train = random.normal(key1, (20, 32, 32, 3)) - >>> y_train = random.uniform(key1, (20, 10)) - >>> x_test = random.normal(key2, (5, 32, 32, 3)) - >>> - >>> init_fn, apply_fn, kernel_fn = stax.serial( - >>> stax.Conv(128, (3, 3)), - >>> stax.Relu(), - >>> stax.Conv(256, (3, 3)), - >>> stax.Relu(), - >>> stax.Conv(512, (3, 3)), - >>> stax.Flatten(), - >>> stax.Dense(10) - >>> ) - >>> - >>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, - >>> y_train) - >>> - >>> # (5, 10) np.ndarray NNGP test prediction - >>> y_test_nngp = predict_fn(x_test=x_test, get='nngp') - >>> - >>> # (5, 10) np.ndarray NTK prediction - >>> y_test_ntk = predict_fn(x_test=x_test, get='ntk') + >>> from jax import random + >>> import neural_tangents as nt + >>> from neural_tangents import stax + >>> # + >>> key1, key2 = random.split(random.PRNGKey(1), 2) + >>> x_train = random.normal(key1, (20, 32, 32, 3)) + >>> y_train = random.uniform(key1, (20, 10)) + >>> x_test = random.normal(key2, (5, 32, 32, 3)) + >>> # + >>> init_fn, apply_fn, kernel_fn = stax.serial( + >>> stax.Conv(128, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(256, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(512, (3, 3)), + >>> stax.Flatten(), + >>> stax.Dense(10) + >>> ) + >>> # + >>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, + >>> y_train) + >>> # + >>> # (5, 10) np.ndarray NNGP test prediction + >>> y_test_nngp = predict_fn(x_test=x_test, get='nngp') + >>> # + >>> # (5, 10) np.ndarray NTK prediction + >>> y_test_ntk = predict_fn(x_test=x_test, get='ntk') """ @@ -125,6 +134,8 @@ # Enums to specify layer behavior. from ._src.stax.linear import ( + AggregateImplementation, + AttentionMechanism, Padding, PositionalEmbedding, ) diff --git a/notebooks/empirical_ntk_resnet.ipynb b/notebooks/empirical_ntk_resnet.ipynb index a715f4a2..6b3596e8 100644 --- a/notebooks/empirical_ntk_resnet.ipynb +++ b/notebooks/empirical_ntk_resnet.ipynb @@ -359,6 +359,7 @@ } ], "source": [ + "#@test {\"skip\": true}\n", "# Structured derivatives\n", "k_3 = ntk_fn_str_derivatives(x1, x2, params)\n", "print(k_3.shape)" diff --git a/notebooks/experimental/empirical_ntk_resnet_tf.ipynb b/notebooks/experimental/empirical_ntk_resnet_tf.ipynb index 62d8fd7c..9f33945b 100644 --- a/notebooks/experimental/empirical_ntk_resnet_tf.ipynb +++ b/notebooks/experimental/empirical_ntk_resnet_tf.ipynb @@ -16,7 +16,8 @@ "id": "nTt0UNQbk_Td" }, "source": [ - "# Example of computing NTK of a **Tensorflow (Keras)** ResNet50 on ImageNet inputs" + "# Example of computing NTK of a **Tensorflow (Keras)** ResNet50 on ImageNet inputs\n", + "Warning: computing the NTK in Tensorflow currently appears to have very long compile times (but OK runtime), can be prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model." ] }, { diff --git a/setup.py b/setup.py index 8ab74f44..6861b15e 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ """Setup the package with pip.""" - import os import setuptools @@ -30,17 +29,14 @@ 'jax>=0.3.13', 'frozendict>=2.3', 'typing_extensions>=4.0.1', - 'tf2jax @ git+https://github.com/deepmind/tf2jax', + 'tf2jax>=0.3.0', ] TESTS_REQUIRES = [ 'more-itertools', 'tensorflow-datasets', - 'flax>=0.5.1', - # TODO(romann): remove when - # https://github.com/google/flax/issues/2190 is fixed. - 'PyYAML>=6.0' + 'flax>=0.5.2', ] @@ -72,14 +68,16 @@ def _get_version() -> str: author_email='neural-tangents-dev@google.com', install_requires=INSTALL_REQUIRES, extras_require={ - "testing": TESTS_REQUIRES, + 'testing': TESTS_REQUIRES, }, url='https://github.com/google/neural-tangents', download_url='https://pypi.org/project/neural-tangents/', project_urls={ 'Source Code': 'https://github.com/google/neural-tangents', 'Paper': 'https://arxiv.org/abs/1912.02803', + 'Finite Width NTK paper': 'https://arxiv.org/abs/2206.08720', 'Video': 'https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html', + 'Finite Width NTK video': 'https://youtu.be/8MWOhYg89fY?t=10984', 'Documentation': 'https://neural-tangents.readthedocs.io/en/latest/?badge=latest', 'Bug Tracker': 'https://github.com/google/neural-tangents/issues', 'Release Notes': 'https://github.com/google/neural-tangents/releases', diff --git a/tests/empirical_test.py b/tests/empirical_test.py index d092ab11..73be4327 100644 --- a/tests/empirical_test.py +++ b/tests/empirical_test.py @@ -17,7 +17,7 @@ from functools import partial import logging import operator -from typing import Any, Callable, Sequence, Tuple, Optional, Dict +from typing import Any, Callable, Sequence, Tuple, Optional, Dict, List from absl.testing import absltest from absl.testing import parameterized from flax import linen as nn @@ -374,19 +374,23 @@ def layer(N_out): _, params = init_fn(net_key, (x1_1.shape, x1_2.shape)) ntk_fns = { - i: jit(partial(nt.empirical_ntk_fn( - apply_fn, - implementation=i), - params=params)) + i: jit( + partial( + nt.empirical_ntk_fn( + apply_fn, + implementation=i), + params=params)) for i in nt.NtkImplementation } ntk_fns_vmapped = { - i: jit(partial(nt.empirical_ntk_fn( - apply_fn, - implementation=i, - vmap_axes=(0, 0)), - params=params)) + i: jit( + partial( + nt.empirical_ntk_fn( + apply_fn, + implementation=i, + vmap_axes=(0, 0)), + params=params)) for i in nt.NtkImplementation } @@ -603,7 +607,7 @@ def get_x(n, k): 'p[0] @ p[1] @ p[2]': lambda p, x: p[0] @ p[1] @ p[2], 'p[0].T @ p[0]': lambda p, x: p[0].T @ p[0], 'p[1].T @ p[0]': lambda p, x: p[1].T @ p[0], - 'p[2] @ p[0] @ p[1]' : lambda p, x: p[2] @ p[0] @ p[1], + 'p[2] @ p[0] @ p[1]': lambda p, x: p[2] @ p[0] @ p[1], '(p[0] @ p[1], p[0])': lambda p, x: (p[0] @ p[1], p[0]), '(p[0] @ p[1], p[1])': lambda p, x: (p[0] @ p[1], p[1]), '(p[0] @ p[1], p[1].T)': lambda p, x: (p[0] @ p[1], p[1].T), @@ -801,6 +805,7 @@ def _compare_ntks( _j_rules, _s_rules, _fwd, + vmap_axes=None, allow_forward_pass_fail=False, rtol=None, atol=None, @@ -827,6 +832,7 @@ def _compare_ntks( f=f, trace_axes=(), implementation=i, + vmap_axes=vmap_axes, _j_rules=_j_rules, _s_rules=_s_rules, _fwd=_fwd @@ -1002,7 +1008,7 @@ def test_function( # TODO(romann): investigate slow CPU execution. test_utils.skip_test('Skipping large non-structured reshapes on CPU.') - if 'lax.map' in f_name and len(shapes[0][0]) > 0 and shapes[0][0][0] == 0: + if 'lax.map' in f_name and shapes[0][0] and shapes[0][0][0] == 0: # TODO(romann): fix. raise absltest.SkipTest('Zero-length scans not supported without JIT.') @@ -1064,12 +1070,17 @@ def __call__(self, x): class _CNN(nn.Module): + features: int + feature_group_counts: List[int] + @nn.compact def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.Conv(features=self.features, kernel_size=(3, 3), + feature_group_count=self.feature_group_counts[0])(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.Conv(features=self.features, kernel_size=(3, 3), + feature_group_count=self.feature_group_counts[1])(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten @@ -1109,7 +1120,7 @@ class _Encoder(nn.Module): @nn.compact def __call__(self, x): - x = nn.Dense(500, name='fc1')(x) + x = nn.Dense(32, name='fc1')(x) x = nn.relu(x) mean_x = nn.Dense(self.latents, name='fc2_mean')(x) logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x) @@ -1120,9 +1131,9 @@ class _Decoder(nn.Module): @nn.compact def __call__(self, z): - z = nn.Dense(500, name='fc1')(z) + z = nn.Dense(16, name='fc1')(z) z = nn.relu(z) - z = nn.Dense(784, name='fc2')(z) + z = nn.Dense(32, name='fc2')(z) return z @@ -1306,12 +1317,12 @@ def __call__(self, inputs, *, train): def _get_mixer_b16_config() -> Dict[str, Any]: """Returns a narrow Mixer-B/16 configuration.""" return dict( - model_name='Mixer-B_16', - patches={'size': (16, 16)}, - hidden_dim=16, - num_blocks=2, - tokens_mlp_dim=4, - channels_mlp_dim=8, + model_name='Mixer-B_16', + patches={'size': (16, 16)}, + hidden_dim=16, + num_blocks=2, + tokens_mlp_dim=4, + channels_mlp_dim=8, ) @@ -1364,7 +1375,7 @@ def _get_mixer_b16_config() -> Dict[str, Any]: for dtype in [ jax.dtypes.canonicalize_dtype(np.float64), ])) -class FlaxTest(test_utils.NeuralTangentsTestCase): +class FlaxOtherTest(test_utils.NeuralTangentsTestCase): def test_mlp(self, same_inputs, do_jit, do_remat, dtype, j_rules, s_rules, fwd): @@ -1372,24 +1383,12 @@ def test_mlp(self, same_inputs, do_jit, do_remat, dtype, j_rules, k1, k2, ki = random.split(random.PRNGKey(1), 3) - x1 = random.normal(k1, (32, 10), dtype) + x1 = random.normal(k1, (4, 10), dtype) x2 = None if same_inputs else random.normal(k2, (3, 10), dtype) p = model.init(ki, x1) _compare_ntks(self, do_jit, do_remat, model.apply, p, x1, x2, j_rules, - s_rules, fwd) - - def test_cnn(self, same_inputs, do_jit, do_remat, dtype, j_rules, - s_rules, fwd): - test_utils.skip_test(self) - - model = _CNN() - x1 = random.normal(random.PRNGKey(1), (5, 8, 8, 3), dtype) - x2 = None if same_inputs else random.normal(random.PRNGKey(2), (2, 8, 8, 3), - dtype) - p = model.init(random.PRNGKey(0), x1) - _compare_ntks(self, do_jit, do_remat, model.apply, p, x1, x2, j_rules, - s_rules, fwd) + s_rules, fwd, vmap_axes=0) def test_autoencoder(self, same_inputs, do_jit, do_remat, dtype, j_rules, s_rules, fwd): @@ -1407,14 +1406,14 @@ def test_autoencoder(self, same_inputs, do_jit, do_remat, dtype, j_rules, # Test encoding-decoding. _compare_ntks(self, do_jit, do_remat, model.apply, p, x1, x2, j_rules, - s_rules, fwd) + s_rules, fwd, vmap_axes=0) # Test encoding. def encode(p, x): return model.apply(p, x, method=model.encode) _compare_ntks(self, do_jit, do_remat, encode, p, x1, x2, j_rules, - s_rules, fwd) + s_rules, fwd, vmap_axes=0) # Test decoding. x1d = model.apply(p, x1, method=model.encode) @@ -1424,7 +1423,7 @@ def decode(p, x): return model.apply(p, x, method=model.decode) _compare_ntks(self, do_jit, do_remat, decode, p, x1d, x2d, j_rules, - s_rules, fwd) + s_rules, fwd, vmap_axes=0) # Test manual encoding-decoding def encode_decode(p, x): @@ -1434,13 +1433,13 @@ def encode_decode(p, x): # Test encoding-decoding. _compare_ntks(self, do_jit, do_remat, encode_decode, p, x1, x2, j_rules, - s_rules, fwd) + s_rules, fwd, vmap_axes=0) def test_vae(self, same_inputs, do_jit, do_remat, dtype, j_rules, s_rules, fwd): test_utils.skip_test(self) - model = _VAE(latents=1) + model = _VAE(latents=2) k1, k2, ki, kzi, kza = random.split(random.PRNGKey(1), 5) x1 = random.normal(k1, (1, 1), dtype) x2 = None if same_inputs else random.normal(k2, (1, 1), dtype) @@ -1486,5 +1485,167 @@ def apply_fn(params, x): s_rules, fwd) +@parameterized.product( + j_rules=[ + True, + False + ], + s_rules=[ + True, + False + ], + fwd=[ + True, + False, + None, + ], + same_inputs=[ + # True, + False + ], + do_jit=[ + True, + # False + ], + do_remat=[ + # True, + False + ], + dtype=[ + jax.dtypes.canonicalize_dtype(np.float64), + ], + feature_group_counts=[ + [1, 1], + [1, 5], + [5, 1], + [5, 5] + ], +) +class FlaxCnnTest(test_utils.NeuralTangentsTestCase): + + def test_flax_cnn(self, same_inputs, do_jit, do_remat, dtype, j_rules, + s_rules, fwd, feature_group_counts): + test_utils.skip_test(self) + n_chan = 5 + x1 = random.normal(random.PRNGKey(1), (2, 8, 8, n_chan), dtype) + x2 = None if same_inputs else random.normal(random.PRNGKey(2), + (3, 8, 8, n_chan), + dtype) + model = _CNN(n_chan, feature_group_counts) + p = model.init(random.PRNGKey(0), x1) + _compare_ntks(self, do_jit, do_remat, model.apply, p, x1, x2, j_rules, + s_rules, fwd, vmap_axes=0) + + +@parameterized.product( + j_rules=[ + True, + False + ], + s_rules=[ + True, + False + ], + fwd=[ + True, + False, + None, + ], + same_inputs=[ + # True, + False + ], + do_jit=[ + True, + # False + ], + do_remat=[ + # True, + False + ], + dtype=[ + jax.dtypes.canonicalize_dtype(np.float64), + ], + n_chan_in=[ + 1, + 2, + 3, + 4 + ], + batch_size=[ + 1, + 2, + 3, + 4 + ], + group_count=[ + 1, + 2, + 4, + 8, + 16, + ], + group_mode=[ + 'batch', + 'feature' + ], + vmap_axes=[ + 0, + None + ] +) +class ConvTest(test_utils.NeuralTangentsTestCase): + + def test_conv( + self, + same_inputs, + do_jit, + do_remat, + dtype, + j_rules, + s_rules, + fwd, + n_chan_in, + batch_size, + group_count, + group_mode, + vmap_axes + ): + # TODO(b/235167364): unskip when the bug is fixed. + test_utils.skip_test(self, platforms=('cpu', 'tpu',)) + + n_chan_out = 16 + + if group_mode == 'batch': + batch_group_count = group_count + feature_group_count = 1 + if vmap_axes == 0 and group_count > 1: + raise absltest.SkipTest('Batch grouped convolution not vmap-able.') + + elif group_mode == 'feature': + batch_group_count = 1 + feature_group_count = group_count + + else: + raise ValueError(group_mode) + + n_chan_in *= feature_group_count + batch_size *= batch_group_count + + x1 = random.normal(random.PRNGKey(1), (batch_size, n_chan_in, 5, 4), dtype) + x2 = None if same_inputs else random.normal(random.PRNGKey(2), + (batch_size, n_chan_in, 5, 4), + dtype) + p = random.normal(random.PRNGKey(2), + (n_chan_out, n_chan_in // feature_group_count, 3, 2)) + def f(p, x): + return lax.conv_general_dilated(x, p, (1, 1), 'SAME', + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) + + _compare_ntks(self, do_jit, do_remat, f, p, x1, x2, j_rules, s_rules, fwd, + vmap_axes=vmap_axes) + + if __name__ == '__main__': absltest.main() diff --git a/tests/experimental/empirical_tf_test.py b/tests/experimental/empirical_tf_test.py index a986ec02..23dd3e3f 100644 --- a/tests/experimental/empirical_tf_test.py +++ b/tests/experimental/empirical_tf_test.py @@ -16,6 +16,8 @@ from absl.testing import absltest from absl.testing import parameterized +import jax +from jax import numpy as np import neural_tangents as nt from neural_tangents import experimental import numpy as onp @@ -25,26 +27,6 @@ tf.random.set_seed(1) -_input_signature = [tf.TensorSpec((1, 2, 1, 4)), - tf.TensorSpec((None, 2, 3, 1))] - - -def _f1(params, x): - return x * tf.reduce_mean(params**2) + 1. - - -def _f2(params, x): - return tf.reduce_mean(x) * params**2 + 1. - - -def _f3(params, x): - return _f1(params, _f1(params, x)) + tf.reduce_sum(_f2(params, x)) - - -def _f4(params, x): - return _f1(params, x) + tf.reduce_sum(_f2(params, _f3(params, x))) - - # TF module copied from https://www.tensorflow.org/api_docs/python/tf/Module @@ -57,7 +39,7 @@ def __init__(self, input_dim, output_size, name=None): self.b = tf.Variable(tf.zeros([1, output_size]), name='b') def __call__(self, x): - y = tf.matmul(x, self.w) + self.b + y = tf.matmul(x, self.w) / x.shape[-1]**0.5 + self.b return tf.nn.relu(y) @@ -79,11 +61,137 @@ def __call__(self, x): return x +# Functions to compare TF/JAX manually. + + +_input_signature = [tf.TensorSpec((1, 2, 1, 4)), + tf.TensorSpec((None, 2, 3, 1))] + + +def _f1(params, x): + return x * tf.reduce_mean(params**2) + 1. + + +def _f1_jax(params, x): + return x * np.mean(params**2) + 1. + + +def _f2(params, x): + return tf.reduce_mean(x) * params**2 + 1. + + +def _f2_jax(params, x): + return np.mean(x) * params**2 + 1. + + +def _f3(params, x): + return _f1(params, _f1(params, x)) + tf.reduce_mean(_f2(params, x)) + + +def _f3_jax(params, x): + return _f1_jax(params, _f1_jax(params, x)) + np.mean(_f2_jax(params, x)) + + +def _f4(params, x): + return _f1(params, x) + tf.reduce_mean(_f2(params, _f3(params, x))) + + +def _f4_jax(params, x): + return _f1_jax(params, x) + np.mean(_f2_jax(params, _f3_jax(params, x))) + + +# ResNet18 adapted from +# https://github.com/jimmyyhwu/resnet18-tf2/blob/master/resnet.py + + +_kaiming_normal = tf.keras.initializers.VarianceScaling( + scale=2.0, mode='fan_out', distribution='untruncated_normal') + + +def _conv3x3(x, out_planes, stride=1, name=None): + x = tf.keras.layers.ZeroPadding2D(padding=1, name=f'{name}_pad')(x) + return tf.keras.layers.Conv2D( + filters=out_planes, kernel_size=3, strides=stride, use_bias=False, + kernel_initializer=_kaiming_normal, name=name)(x) + + +def _basic_block(x, planes, stride=1, downsample=None, name=None): + identity = x + + out = _conv3x3(x, planes, stride=stride, name=f'{name}.conv1') + out = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.bn1')(out) + out = tf.keras.layers.ReLU(name=f'{name}.relu1')(out) + + out = _conv3x3(out, planes, name=f'{name}.conv2') + out = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.bn2')(out) + + if downsample is not None: + for layer in downsample: + identity = layer(identity) + + out = tf.keras.layers.Add(name=f'{name}.add')([identity, out]) + out = tf.keras.layers.ReLU(name=f'{name}.relu2')(out) + + return out + + +def _make_layer(x, planes, blocks, stride=1, name=None): + downsample = None + inplanes = x.shape[3] + if stride != 1 or inplanes != planes: + downsample = [ + tf.keras.layers.Conv2D( + filters=planes, kernel_size=1, strides=stride, + use_bias=False, kernel_initializer=_kaiming_normal, + name=f'{name}.0.downsample.0'), + tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name=f'{name}.0.downsample.1'), + ] + + x = _basic_block(x, planes, stride, downsample, name=f'{name}.0') + for i in range(1, blocks): + x = _basic_block(x, planes, name=f'{name}.{i}') + + return x + + +def _resnet(x, blocks_per_layer, classes, filters): + x = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(x) + x = tf.keras.layers.Conv2D( + filters=filters, kernel_size=7, strides=2, use_bias=False, + kernel_initializer=_kaiming_normal, name='conv1')(x) + x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, + name='bn1')(x) + x = tf.keras.layers.ReLU(name='relu1')(x) + x = tf.keras.layers.ZeroPadding2D(padding=1, name='maxpool_pad')(x) + x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, name='maxpool')(x) + + x = _make_layer(x, filters, blocks_per_layer[0], name='layer1') + x = _make_layer(x, 2 * filters, blocks_per_layer[1], stride=2, name='layer2') + + x = tf.keras.layers.GlobalAveragePooling2D(name='avgpool')(x) + initializer = tf.keras.initializers.RandomUniform(-1.0 / (2 * filters)**0.5, + 1.0 / (2 * filters)**0.5) + x = tf.keras.layers.Dense(units=classes, kernel_initializer=initializer, + bias_initializer=initializer, name='fc')(x) + + return x + + +def _MiniResNet(classes, input_shape, weights): + inputs = tf.keras.Input(shape=input_shape) + outputs = _resnet(inputs, [1, 1, 1, 1], classes=classes, filters=2) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + class EmpiricalTfTest(parameterized.TestCase): def _compare_ntks( self, f, + f_jax, params, trace_axes, diagonal_axes, @@ -93,13 +201,20 @@ def _compare_ntks( raise absltest.SkipTest('Overlapping trace and diagonal axes.') kwargs = dict( - f=f, trace_axes=trace_axes, diagonal_axes=diagonal_axes, ) + jax_ntk_fns = [ + jax.jit(nt.empirical_ntk_fn( + **kwargs, f=f_jax, implementation=i, vmap_axes=v)) + for i in nt.NtkImplementation + for v in vmap_axes if v not in trace_axes + diagonal_axes + ] + ntk_fns = [ experimental.empirical_ntk_fn_tf(**kwargs, + f=f, implementation=i, vmap_axes=v) for i in nt.NtkImplementation @@ -109,22 +224,46 @@ def _compare_ntks( x_shape = (f.input_shape[1:] if isinstance(f, tf.Module) else f.input_signature[1].shape[1:]) - x1 = tf.random.normal((2,) + x_shape, seed=2) - x2 = tf.random.normal((3,) + x_shape, seed=3) + x1 = tf.random.normal((2,) + x_shape, seed=2) / onp.prod(x_shape)**0.5 + x2 = tf.random.normal((3,) + x_shape, seed=3) / onp.prod(x_shape)**0.5 + + x1_jax = np.array(x1) + x2_jax = np.array(x2) + params_jax = jax.tree_map(lambda x: np.array(x), params) + + jax_ntks = [ntk_fn_i(x1_jax, x2_jax, params_jax) + for ntk_fn_i in jax_ntk_fns] ntks = list(enumerate([ntk_fn_i(x1, x2, params) for ntk_fn_i in ntk_fns])) + if len(tf.config.list_physical_devices()) > 1: # TPU + atol = 0. + rtol = 5e-3 + atol_jax = 0.4 + rtol_jax = 0.15 # TODO(romann): revisit poor TPU agreement. + else: + atol = 1e-5 + rtol = 1e-5 + atol_jax = 0. + rtol_jax = 1e-5 + for i1, ntk1 in ntks: for i2, ntk2 in ntks[i1 + 1:]: - onp.testing.assert_allclose(ntk1, ntk2, rtol=3e-5, atol=1e-5) + # Compare different implementation + onp.testing.assert_allclose(ntk1, ntk2, rtol=rtol, atol=atol) + # Compare against the JAX version (without calling `jax2tf`). + onp.testing.assert_allclose(ntk1, jax_ntks[i1], rtol=rtol_jax, + atol=atol_jax) @parameterized.product( f=[ - tf.keras.applications.MobileNet, + _MiniResNet, + # # TODO(romann): MobileNet works, but takes too long to compile. + # tf.keras.applications.MobileNet, ], input_shape=[ - (32, 32, 3) + (64, 64, 3) ], trace_axes=[ (), @@ -146,18 +285,14 @@ def test_keras_functional( diagonal_axes, vmap_axes, ): - if len(tf.config.list_physical_devices()) != 2: - # TODO(romann): file bugs on enormous compile time on CPU and TPU. - raise absltest.SkipTest('Skipping CPU and TPU keras tests.') - f = f(classes=1, input_shape=input_shape, weights=None) f.build((None, *input_shape)) - _, params = experimental.get_apply_fn_and_params(f) - self._compare_ntks(f, params, trace_axes, diagonal_axes, vmap_axes) + f_jax, params = experimental.get_apply_fn_and_params(f) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) @parameterized.product( input_shape=[ - (16, 16, 3) + (32, 32, 3) ], trace_axes=[ (), @@ -179,8 +314,7 @@ def test_keras_sequential( vmap_axes, ): if tf.config.list_physical_devices('GPU') and diagonal_axes: - # TODO(romann): figure out the `XlaRuntimeError`. - raise absltest.SkipTest('RET_CHECK failure') + raise absltest.SkipTest('http://b/237035658') f = tf.keras.Sequential() f.add(tf.keras.layers.Conv2D(4, (3, 3), activation='relu')) @@ -189,15 +323,15 @@ def test_keras_sequential( f.add(tf.keras.layers.Dense(2)) f.build((None, *input_shape)) - _, params = experimental.get_apply_fn_and_params(f) - self._compare_ntks(f, params, trace_axes, diagonal_axes, vmap_axes) + f_jax, params = experimental.get_apply_fn_and_params(f) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) @parameterized.product( - f=[ - _f1, - _f2, - _f3, - _f4, + f_f_jax=[ + (_f1, _f1_jax), + (_f2, _f2_jax), + (_f3, _f3_jax), + (_f4, _f4_jax) ], params_shape=[ _input_signature[0].shape @@ -216,15 +350,16 @@ def test_keras_sequential( ) def test_tf_function( self, - f, + f_f_jax, params_shape, trace_axes, diagonal_axes, vmap_axes, ): + f, f_jax = f_f_jax f = tf.function(f, input_signature=_input_signature) params = tf.random.normal(params_shape, seed=4) - self._compare_ntks(f, params, trace_axes, diagonal_axes, vmap_axes) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) @parameterized.product( trace_axes=[ @@ -239,15 +374,15 @@ def test_tf_function( (0, None) ] ) - def test_module( + def test_tf_module( self, trace_axes, diagonal_axes, vmap_axes, ): f = _MLP(input_size=5, sizes=[4, 6, 3], name='MLP') - _, params = experimental.get_apply_fn_and_params(f) - self._compare_ntks(f, params, trace_axes, diagonal_axes, vmap_axes) + f_jax, params = experimental.get_apply_fn_and_params(f) + self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes) if __name__ == '__main__':