1) Various documentation improvements:
- Use proper RST syntax for warning/see-also admonitions.
- Set default role to `code` globally to better render `code blocks` in auto-generated pages.
- Change the `code` CSS style.
- Better document enums.
- Add some missing public enums to `stax`.
- Use RST :roles: and intersphinx to create clickable references to NT and JAX/NumPy/etc. objects.
- Make `Kernel` public, and remove a circular dependency in typing -> Kernel -> utils -> typing.
- Fix interactive Python `>>>` code blocks so that copy-pasting does not pick up extra `>>>` prompts or spaces.
- Add citations and references to all papers that contributed to NT.
- Add basic `experimental` documentation.
- Minor typing and linting fixes.

2) Fix a bug in the empirical kernel of grouped CNNs, and add respective tests (see the illustrative sketch after item 4).

3) Add more tests to Tensorflow NTK, comparing TF and JAX NTKs.

4) Reduce some sizes to deflake tests.
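
A minimal sketch of the grouped-convolution case from item 2, assuming only `jax` and `neural_tangents`; the architecture, shapes, and parameter names are hypothetical, not taken from the new tests:

```python
# Hypothetical sketch: empirical NTK of a grouped convolution, the case whose
# kernel this commit fixes. Any differentiable apply_fn(params, x) works here.
import jax
import neural_tangents as nt


def apply_fn(params, x):
  # feature_group_count=2 splits the 4 input channels into 2 groups.
  return jax.lax.conv_general_dilated(
      x, params['w'],
      window_strides=(1, 1),
      padding='SAME',
      feature_group_count=2,
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x1 = jax.random.normal(k1, (3, 8, 8, 4))
x2 = jax.random.normal(k2, (2, 8, 8, 4))
# HWIO filter: input-channel dimension is 4 channels / 2 groups = 2.
params = {'w': jax.random.normal(k3, (3, 3, 2, 8))}

ntk_fn = nt.empirical_ntk_fn(apply_fn)
k = ntk_fn(x1, x2, params)  # Empirical NTK between x1 and x2.
```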

Tests: http://sponge2/fa1e9c8a-6fde-45d9-b301-93c7c47d28d8
PiperOrigin-RevId: 457065540
romanngg committed Jun 24, 2022
1 parent cff2f81 commit eec18f5
Showing 37 changed files with 1,650 additions and 940 deletions.
32 changes: 32 additions & 0 deletions CITATION
@@ -1,7 +1,39 @@
# Infinite width NTK/NNGP:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
pdf={https://arxiv.org/abs/1912.02803},
url={https://github.com/google/neural-tangents}
}

# Finite width, empirical NTK/NNGP:
@inproceedings{novak2022fast,
title={Fast Finite Width Neural Tangent Kernel},
author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Machine Learning},
year={2022},
pdf={https://arxiv.org/abs/2206.08720},
url={https://github.com/google/neural-tangents}
}

# Attention and variable-length inputs:
@inproceedings{hron2020infinite,
title={Infinite attention: NNGP and NTK for deep attention networks},
author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},
booktitle={International Conference on Machine Learning},
year={2020},
pdf={https://arxiv.org/abs/2006.10540},
url={https://github.com/google/neural-tangents}
}

# Infinite-width "standard" parameterization:
@misc{sohl2020on,
title={On the infinite width limit of neural networks with a standard parameterization},
author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee},
publisher = {arXiv},
year={2020},
pdf={https://arxiv.org/abs/2001.07301},
url={https://github.com/google/neural-tangents}
}
39 changes: 38 additions & 1 deletion README.md
@@ -2,6 +2,10 @@

Freedom of thought is fundamental to all of science. Right now, our freedom is being suppressed with carpet bombing of civilians in Ukraine. **Don't be against the war - fight against the war!** Support Ukraine at **[stopputin.net](https://stopputin.net/)**.

### News

Our paper "[Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720)" has been accepted to ICML 2022, and the respective code has been [submitted](https://github.com/romanngg/neural-tangents/commit/60b6c16758652f4526536409b5bc90602287b868) (available starting from version `0.6.0`).

# Neural Tangents
[**ICLR 2020 Video**](https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html)
| [**Paper**](https://arxiv.org/abs/1912.02803)
@@ -60,6 +64,7 @@ An easy way to get started with Neural Tangents is by playing around with the following
- Empirical NTK:
- [Fully-connected network](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb)
- [FLAX ResNet18](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb)
- [Experimental: Tensorflow ResNet50](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb)

## Installation

@@ -528,14 +533,46 @@ to the list!

## Citation

If you use the code in a publication, please cite our ICLR 2020 paper:
If you use the code in a publication, please cite our papers:

```
# Infinite width NTK/NNGP:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
pdf={https://arxiv.org/abs/1912.02803},
url={https://github.com/google/neural-tangents}
}
# Finite width, empirical NTK/NNGP:
@inproceedings{novak2022fast,
title={Fast Finite Width Neural Tangent Kernel},
author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Machine Learning},
year={2022},
pdf={https://arxiv.org/abs/2206.08720},
url={https://github.com/google/neural-tangents}
}
# Attention and variable-length inputs:
@inproceedings{hron2020infinite,
title={Infinite attention: NNGP and NTK for deep attention networks},
author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},
booktitle={International Conference on Machine Learning},
year={2020},
pdf={https://arxiv.org/abs/2006.10540},
url={https://github.com/google/neural-tangents}
}
# Infinite-width "standard" parameterization:
@misc{sohl2020on,
title={On the infinite width limit of neural networks with a standard parameterization},
author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee},
publisher = {arXiv},
year={2020},
pdf={https://arxiv.org/abs/2001.07301},
url={https://github.com/google/neural-tangents}
}
```
5 changes: 5 additions & 0 deletions docs/_static/style.css
@@ -1,3 +1,8 @@
.wy-nav-content {
max-width: none;
}

.rst-content code.literal, .rst-content tt.literal {
color: #404040;
white-space: normal
}
4 changes: 2 additions & 2 deletions docs/batching.rst
@@ -1,11 +1,11 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/batching.rst

.. default-role:: code


`nt.batch` -- using multiple devices
============================================================

.. default-role:: code

.. automodule:: neural_tangents._src.batching
.. automodule:: neural_tangents
:noindex:
9 changes: 7 additions & 2 deletions docs/conf.py
@@ -80,6 +80,7 @@
'python': ('https://docs.python.org/3/', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
'jax': ('https://jax.readthedocs.io/en/latest/', None),
}


@@ -172,16 +173,20 @@
'Miscellaneous'),
]


# add_module_names = False


default_role = 'code'


def remove_module_docstring(app, what, name, obj, options, lines):
if what == "module" and name == "neural_tangents":
if what == 'module' and name == 'neural_tangents':
del lines[:]


def setup(app):
app.connect("autodoc-process-docstring", remove_module_docstring)
app.connect('autodoc-process-docstring', remove_module_docstring)
app.add_css_file('style.css')


10 changes: 3 additions & 7 deletions docs/empirical.rst
@@ -1,6 +1,6 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/empirical.rst

.. default-role:: code


`nt.empirical` -- finite NNGP and NTK
======================================
@@ -21,13 +21,9 @@ Finite-width NNGP and/or NTK kernel functions.

NTK implementation
--------------------------------------
An `IntEnum` specifying NTK implementation method.

.. autosummary::
:toctree: _autosummary

NtkImplementation
An :class:`enum.IntEnum` specifying the NTK implementation method.

.. autoclass:: NtkImplementation
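
For context, a hedged usage sketch (not part of the diff): `NtkImplementation` is importable from the top-level package per the `__init__.py` change below; the member used here follows the Fast Finite Width Neural Tangent Kernel paper and is an assumption if the released API differs.

```python
# Hedged sketch: selecting an NTK implementation method. STRUCTURED_DERIVATIVES
# is the method the Fast Finite Width NTK paper reports as typically fastest.
import jax
import jax.numpy as jnp
import neural_tangents as nt


def apply_fn(params, x):
  return jnp.tanh(x @ params['w'])


params = {'w': jax.random.normal(jax.random.PRNGKey(0), (8, 4))}
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))

ntk_fn = nt.empirical_ntk_fn(
    apply_fn,
    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
k = ntk_fn(x, None, params)  # x2=None computes the NTK of x with itself.
```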

Linearization and Taylor expansion
--------------------------------------
24 changes: 24 additions & 0 deletions docs/experimental.rst
@@ -0,0 +1,24 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/experimental.rst



`nt.experimental` -- prototypes
======================================

.. warning::
This module contains new highly-experimental prototypes. Please beware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk!

.. automodule:: neural_tangents.experimental
.. currentmodule:: neural_tangents.experimental

Kernel functions
--------------------------------------
Finite-width NTK kernel function *in Tensorflow*. See the `Python <https://github.com/google/neural-tangents/blob/main/examples/experimental/empirical_ntk_tf.py>`_ and `Colab <https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb>`_ usage examples.

.. autofunction:: empirical_ntk_fn_tf

Helper functions
--------------------------------------
A helper function to convert stateful Tensorflow models into a functional-style, stateless `apply_fn(params, x)` forward-pass function, and to extract the respective `params`.

.. autofunction:: get_apply_fn_and_params
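
Combining the two functions above, a hedged end-to-end sketch modeled on the linked Python example; the exact signatures of these prototypes are assumptions and may change:

```python
# Hedged sketch of the experimental TF NTK prototype; argument names are
# assumptions based on the linked example, not a stable API.
import tensorflow as tf
import neural_tangents as nt

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.build(input_shape=(None, 16))

# Stateful Keras model -> stateless apply_fn(params, x) plus its params.
apply_fn, params = nt.experimental.get_apply_fn_and_params(model)

ntk_fn = nt.experimental.empirical_ntk_fn_tf(apply_fn, vmap_axes=0)
x1 = tf.random.normal((4, 16))
x2 = tf.random.normal((3, 16))
k = ntk_fn(x1, x2, params)  # Finite-width NTK, computed in Tensorflow.
```

Item 3 of the commit message adds tests comparing the output of this TF NTK against the JAX `nt.empirical_ntk_fn` applied to the same function.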
19 changes: 15 additions & 4 deletions docs/index.rst
@@ -1,6 +1,6 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/index.rst

.. default-role:: code


Neural Tangents Reference
===========================================
@@ -24,6 +24,7 @@ neural networks (a.k.a. NTK, NNGP).

kernel
typing
experimental

.. toctree::
:maxdepth: 2
@@ -34,15 +35,25 @@ neural networks (a.k.a. NTK, NNGP).
Function Space Linearization <https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/function_space_linearization.ipynb>
Neural Network Phase Diagram <https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/phase_diagram.ipynb>
Performance Benchmarks <https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb>
Finite Width NTK <https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb>

.. toctree::
:maxdepth: 2
:caption: Papers:

Neural Tangents: Fast and Easy Infinite Neural Networks in Python <https://arxiv.org/abs/1912.02803>
Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>
Infinite attention: NNGP and NTK for deep attention networks <https://arxiv.org/abs/2006.10540>
On the infinite width limit of neural networks with a standard parameterization <https://arxiv.org/abs/2001.07301>

.. toctree::
:maxdepth: 2
:caption: Other Resources:

GitHub <https://github.com/google/neural-tangents>
Paper <https://arxiv.org/abs/1912.02803>
Video <https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html>
Neural Tangents Video <https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html>
Finite Width NTK Video <https://youtu.be/8MWOhYg89fY?t=10984>
Wikipedia <https://en.wikipedia.org/wiki/Large_width_limits_of_neural_networks>
GitHub <https://github.com/google/neural-tangents>

Indices and tables
==================
6 changes: 3 additions & 3 deletions docs/kernel.rst
@@ -1,9 +1,9 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/kernel.rst

.. default-role:: code

`Kernel` dataclass

:class:`~neural_tangents.Kernel` dataclass
=============================================================

.. automodule:: neural_tangents._src.utils.kernel
.. autoclass:: neural_tangents.Kernel
:members:
4 changes: 2 additions & 2 deletions docs/monte_carlo.rst
@@ -1,11 +1,11 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/monte_carlo.rst

.. default-role:: code


`nt.monte_carlo_kernel_fn` - MC Sampling
=========================================

.. default-role:: code

.. automodule:: neural_tangents._src.monte_carlo
.. automodule:: neural_tangents
:members: monte_carlo_kernel_fn
2 changes: 1 addition & 1 deletion docs/predict.rst
@@ -1,6 +1,6 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/predict.rst

.. default-role:: code


`nt.predict` -- inference w/ NNGP & NTK
=============================================================
5 changes: 4 additions & 1 deletion docs/stax.rst
@@ -1,10 +1,11 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/stax.rst

.. default-role:: code


`nt.stax` -- infinite NNGP and NTK
===========================================


.. automodule:: neural_tangents.stax


@@ -99,6 +100,8 @@ Enums for specifying layer properties. Strings can be used in their place.
.. autosummary::
:toctree: _autosummary

AggregateImplementation
AttentionMechanism
Padding
PositionalEmbedding

2 changes: 1 addition & 1 deletion docs/typing.rst
@@ -1,6 +1,6 @@
:github_url: https://github.com/google/neural-tangents/tree/main/docs/typing.rst

.. default-role:: code


Typing
=============================================================
3 changes: 3 additions & 0 deletions examples/empirical_ntk.py
@@ -16,6 +16,9 @@
All implementations apply to any differentiable function, not necessarily one
constructed with Neural Tangents.
For details about the empirical (finite width) NTK computation, please see
"`Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_".
"""

from absl import app
10 changes: 9 additions & 1 deletion examples/experimental/empirical_ntk_tf.py
@@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Minimal highly-experimental Tensorflow NTK example."""
"""Minimal highly-experimental Tensorflow NTK example.
Specifically, Tensorflow NTK appears to have very long compile times (but OK
runtime), is prone to triggering XLA errors, and does not distinguish between
trainable and non-trainable parameters of the model.
For details about the empirical (finite width) NTK computation, please see
"`Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_".
"""

from absl import app
import neural_tangents as nt
8 changes: 2 additions & 6 deletions neural_tangents/__init__.py
@@ -18,19 +18,15 @@

__version__ = '0.6.0'


from . import experimental
from . import predict
from . import stax

from ._src.batching import batch

from ._src.empirical import empirical_kernel_fn
from ._src.empirical import empirical_nngp_fn
from ._src.empirical import empirical_ntk_fn
from ._src.empirical import linearize
from ._src.empirical import NtkImplementation
from ._src.empirical import taylor_expand

from ._src.monte_carlo import monte_carlo_kernel_fn

from . import experimental
from ._src.utils.kernel import Kernel