Skip to content

Commit

Permalink
Flax/nnx backend (#440)
Browse files Browse the repository at this point in the history
* add flax v0.8.0 to deps, temporarily from github main branch

* main gps objects as nnx modules

* integrators as nnx dataclasses and some static typing refactoring

* likelihoods as nnx dataclasses modules and some static typing refactoring

* small refactoring

* mean functions as nnx dataclasses modules and some refactoring

* bugfix

* objectives as nnx dataclasses modules

* variational families with nnx

* kernels base with nnx

* wip stationary kernels

* wip nonstationary kernels

* wip non euclidean kernels

* computations with nnx

* rff with nnx

* bugfix

* stationary kernels as normal classes

* nonstationary kernels as normal classes

* noneuclidean kernels as normal classes

* rff as standard class + stationary kernel abstract class for static typing

* started work on parameters

* more objects as normal classes

* gps as normal classes

* integrators as normal classes

* dataset is not a pytree

* removed superfluous inits

* register dataset as pytree

* use parameters here and there

* set active_dims default to 1

* start working on tests

* active_dims defaults to None

* rewrite objectives as functions

Co-authored-by: Daniel Dodd <d.dodd1@lancaster.ac.uk>

* black + isort

* remove objective from cite

* fix dataset repr

* pass tests for variational families

* active_dims defaults to None

* use generic Objective type

* small fixes

* make 'active_dims' required parameter, fix static typing and beartype for parameters, rewrite and pass tests for stationary kernels

* pass tests/test_kernels/test_computation.py

* rewrite tests for nonstationary kernels + pass tests

* adapt to nnx's explicit variables + miscellaneous fixes

* rewrite of objectives as simple functions, [WIP] started rewriting tests

* rewrite and pass tests for objectives

* rewrite fit function

* remove gpjax.base module

* remove base module tests

* rewrite and pass tests for fit

* finish kernels and pass all tests

* pass all tests except decision making

* pass all tests 🚀

* update and run classification notebook (python cells)

* pass doctests

* pass integration tests, more checks to parameters

* linting and formatting

* update barycentres and classification examples

* update project files

* update ruff and make it happy

* lint + format all doc examples

* [skip ci] change how dimensions are specified for kernels, update kernel tests

* [skip ci] api reference looks pretty now, implemented template pattern, improved docstrings

* [skip ci] wip - fixing math rendering in documentation - almost there

* Update notebooks. (#447)

* Update yacht.py

* Update likelihoods_guide.py

* Revert "Update likelihoods_guide.py"

This reverts commit 5f51cfe.

* Update oceanmodelling.py

* Update likelihoods.py (#446)

* Update likelihoods.py

* Update likelihoods.py

* Update likelihoods.py

* Adding tagged parameters and updated notebooks

* Update likelihoods.py (#446)

* Update likelihoods.py

* Update likelihoods.py

* Update likelihoods.py

* Update notebooks

* Fix linting

* Fix missing dep.

* Fix integration test

* Readd docs deps

* Fix docstrings

* Update lockfile

* Update parameter refs

* Fix broken tests

* Remove PyTrees doc

* Failing split order

* NNX update

* add flax v0.8.0 to deps, temporarily from github main branch

* main gps objects as nnx modules

* integrators as nnx dataclasses and some static typing refactoring

* likelihoods as nnx dataclasses modules and some static typing refactoring

* small refactoring

* mean functions as nnx dataclasses modules and some refactoring

* bugfix

* objectives as nnx dataclasses modules

* variational families with nnx

* kernels base with nnx

* wip stationary kernels

* wip nonstationary kernels

* wip non euclidean kernels

* computations with nnx

* rff with nnx

* bugfix

* stationary kernels as normal classes

* nonstationary kernels as normal classes

* noneuclidean kernels as normal classes

* rff as standard class + stationary kernel abstract class for static typing

* started work on parameters

* more objects as normal classes

* gps as normal classes

* integrators as normal classes

* dataset is not a pytree

* removed superfluous inits

* register dataset as pytree

* use parameters here and there

* set active_dims default to 1

* start working on tests

* active_dims defaults to None

* rewrite objectives as functions

Co-authored-by: Daniel Dodd <d.dodd1@lancaster.ac.uk>

* black + isort

* remove objective from cite

* fix dataset repr

* pass tests for variational families

* active_dims defaults to None

* use generic Objective type

* small fixes

* make 'active_dims' required parameter, fix static typing and beartype for parameters, rewrite and pass tests for stationary kernels

* pass tests/test_kernels/test_computation.py

* rewrite tests for nonstationary kernels + pass tests

* adapt to nnx's explicit variables + miscellaneous fixes

* rewrite of objectives as simple functions, [WIP] started rewriting tests

* rewrite and pass tests for objectives

* rewrite fit function

* remove gpjax.base module

* remove base module tests

* rewrite and pass tests for fit

* finish kernels and pass all tests

* pass all tests except decision making

* pass all tests 🚀

* update and run classification notebook (python cells)

* pass doctests

* pass integration tests, more checks to parameters

* linting and formatting

* update barycentres and classification examples

* update project files

* update ruff and make it happy

* lint + format all doc examples

* [skip ci] change how dimensions are specified for kernels, update kernel tests

* [skip ci] api reference looks pretty now, implemented template pattern, improved docstrings

* [skip ci] wip - fixing math rendering in documentation - almost there

* Update notebooks. (#447)

* Update yacht.py

* Update likelihoods_guide.py

* Revert "Update likelihoods_guide.py"

This reverts commit 5f51cfe.

* Update oceanmodelling.py

* Update likelihoods.py (#446)

* Update likelihoods.py

* Update likelihoods.py

* Update likelihoods.py

* Update notebooks

* Adding tagged parameters and updated notebooks

* Fix linting

* Fix missing dep.

* Fix integration test

* Readd docs deps

* Fix docstrings

* Update lockfile

* Update parameter refs

* Fix broken tests

* Remove PyTrees doc

* Failing split order

* NNX update

* rename static dir

* move examples dir in top level

* add _examples generated dir to gitignore

* update pyproject deps

* update mkdocs config

* add examples generation script

* adapt relative paths in md files

* Update Ruff and incorporate changes

* update github workflow for building doc, without executing notebookf for now

* Add backend doc

* Add backend doc

* Add backend doc

* Add replace to transform

* Merge with main

* Update parameters docstring

* Respond to comments

* Fix e2e tests

* Fix mplstyle refs

* bump deps

* Update poetry

* Update poetry

* Fix shutil

* Drop flax base

* add scikit-learn dependency for docs

* bugfix: change directory before running jupytext

* use local mpl style file

* do not use MCMC for classification (it is *very* slow)

* [skip-ci] update github workflows for docs

* Fix split

* Fix split

* Fix split

* Fix xdoctest

* Fix doc

* Add serial build

* Update parameters transform and backend doc

* Update parameters transform and backend doc

* Bump Python

---------

Signed-off-by: Thomas Pinder <tompinder@live.co.uk>
Co-authored-by: Daniel Dodd <d.dodd1@lancaster.ac.uk>
Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
Co-authored-by: Thomas Pinder <tompinder@live.co.uk>
Co-authored-by: Thomas-Christie <thomashamish@hotmail.com>
  • Loading branch information
5 people authored Aug 16, 2024
1 parent 7ae0adf commit 9ba68a4
Show file tree
Hide file tree
Showing 139 changed files with 6,152 additions and 9,642 deletions.
10 changes: 2 additions & 8 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,16 @@ jobs:
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.2.2
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

- name: Install LaTex
run: |
sudo apt-get update
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super
- name: Build the documentation with MKDocs
run: |
cp docs/examples/gpjax.mplstyle .
poetry install --all-extras --with docs
conda install pandoc
poetry run mkdocs build
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build
- name: Deploy Page 🚀
uses: JamesIves/github-pages-deploy-action@v4.4.1
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
with:
version: 1.4.0
version: 1.5.1

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
Expand All @@ -39,7 +39,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --all-extras --with docs
poetry install --with docs
# Run the unit tests and build the coverage report
- name: Run Integration Tests
Expand Down
19 changes: 2 additions & 17 deletions .github/workflows/test_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,17 @@ jobs:
auto-update-conda: true
python-version: ${{ matrix.python-version }}

# Install katex for math support
- name: Install NPM
uses: actions/setup-node@v3
with:
node-version: 16
- name: Install KaTeX
run: |
npm install katex
- name: Install LaTex
run: |
sudo apt-get update
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super
# Install Poetry and build the documentation
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.2.2
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

- name: Build the documentation with MKDocs
run: |
cp docs/examples/gpjax.mplstyle .
poetry install --all-extras --with docs
conda install pandoc
poetry run mkdocs build
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build
11 changes: 7 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ jobs:
python-version: ${{ matrix.python-version }}

# Install Poetry
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.4.0
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
Expand All @@ -39,7 +42,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --with tests
poetry install --with dev
- name: Check docstrings
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,4 @@ package-lock.json
node_modules/

docs/api
docs/_examples
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ repos:
language: system
types: [python]
exclude: examples/
- repo: https://github.com/econchick/interrogate
rev: 1.5.0
hooks:
- id: interrogate
args:
[
"gpjax",
"--config",
"pyproject.toml",
]
pass_filenames: false
# - repo: https://github.com/econchick/interrogate
# rev: 1.5.0
# hooks:
# - id: interrogate
# args:
# [
# "gpjax",
# "--config",
# "pyproject.toml",
# ]
# pass_filenames: false
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today.
## Notebook examples

> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
> - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/)
> - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
> - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference)
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
Expand Down Expand Up @@ -146,13 +145,10 @@ posterior = prior * likelihood
# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Define the marginal log-likelihood
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
objective=gpx.objectives.conjugate_mll,
train_data=D,
optim=optimiser,
num_iters=500,
Expand Down
Empty file removed benchmarks/__init__.py
Empty file.
25 changes: 0 additions & 25 deletions benchmarks/asv.conf.json

This file was deleted.

99 changes: 0 additions & 99 deletions benchmarks/kernels.py

This file was deleted.

87 changes: 0 additions & 87 deletions benchmarks/objectives.py

This file was deleted.

Loading

0 comments on commit 9ba68a4

Please sign in to comment.