Jaxopt 2 #402

Merged (13 commits) on Nov 7, 2023

4 changes: 2 additions & 2 deletions README.md
@@ -40,13 +40,13 @@ Another way you can contribute to GPJax is through [issue
triaging](https://www.codetriage.com/what). This can include reproducing bug reports,
asking for vital information such as version numbers and reproduction instructions, or
identifying stale issues. If you would like to begin triaging issues, an easy way to get
started is to
[subscribe to GPJax on CodeTriage](https://www.codetriage.com/jaxgaussianprocesses/gpjax).

As a contributor to GPJax, you are expected to abide by our [code of
conduct](docs/CODE_OF_CONDUCT.md). If you feel that you have either experienced or
witnessed behaviour that violates this standard, then we ask that you report any such
behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or reach out to
one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).

Feel free to join our [Slack
7 changes: 2 additions & 5 deletions docs/examples/barycentres.py
@@ -137,13 +137,10 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

opt_posterior, _ = gpx.fit(
opt_posterior, _ = gpx.fit_scipy(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
objective=gpx.ConjugateMLL(negative=True),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=500,
key=key,
)
latent_dist = opt_posterior.predict(xtest, train_data=D)
return opt_posterior.likelihood(latent_dist)
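For readers skimming the diff, the swap made throughout these notebooks follows one pattern, restated below in isolation. This is a sketch, not part of the PR: it assumes the GPJax constructors and the `fit`/`fit_scipy` signatures exactly as they appear in the hunks above, and the toy sine dataset and PRNG key are illustrative only.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

key = jr.PRNGKey(123)
x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(10.0 * x)
D = gpx.Dataset(X=x, y=y)

likelihood = gpx.Gaussian(num_datapoints=D.n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

# Before this PR: first-order optimisation with optax, which needs an optimiser,
# an iteration budget, and a PRNG key.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=jax.jit(gpx.ConjugateMLL(negative=True)),
    train_data=D,
    optim=ox.adamw(learning_rate=0.01),
    num_iters=500,
    key=key,
)

# After this PR: BFGS through the SciPy wrapper, which needs no optimiser,
# iteration count, or key.
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
)
```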
13 changes: 4 additions & 9 deletions docs/examples/constructing_new_kernels.py
@@ -40,7 +40,6 @@
)
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

@@ -214,13 +213,13 @@ def angular_distance(x, y, c):
return jnp.abs((x - y + c) % (c * 2) - c)


bij = tfb.Chain([tfb.Softplus(), tfb.Shift(np.array(4.0).astype(np.float64))])
bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))


@dataclass
class Polar(gpx.kernels.AbstractKernel):
period: float = static_field(2 * jnp.pi)
tau: float = param_field(jnp.array([4.0]), bijector=bij)
tau: float = param_field(jnp.array([5.0]), bijector=bij)

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
@@ -267,14 +266,11 @@ def __call__(
likelihood = gpx.Gaussian(num_datapoints=n)
circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using Adam
opt_posterior, history = gpx.fit(
# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.05),
num_iters=500,
key=key,
)

# %% [markdown]
@@ -314,7 +310,6 @@ def __call__(
ax.plot(angles, mu, label="Posterior mean")
ax.scatter(D.X, D.y, alpha=1, label="Observations")
ax.legend()

# %% [markdown]
# ## System configuration

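An aside on the bijector change in this file: `tfb.SoftClip(low=4.0)` maps an unconstrained real smoothly onto values strictly above 4 and is close to the identity far from the bound, whereas the previous `tfb.Chain([tfb.Softplus(), tfb.Shift(4.0)])` composes right-to-left and therefore computes `softplus(x + 4)`, which only enforces positivity; presumably that mismatch motivated the swap, and moving the initial `tau` from 4.0 to 5.0 is consistent with the parameter now having to sit above the new lower bound. A quick numerical check, as a sketch assuming only `tensorflow_probability.substrates.jax` and the double-precision config the notebook already enables (the `old_bij`/`new_bij` names are illustrative):

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfb = tfp.bijectors

old_bij = tfb.Chain([tfb.Softplus(), tfb.Shift(jnp.array(4.0))])
new_bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))

z = jnp.array([-6.0, 0.0, 6.0, 60.0])  # unconstrained values

# The old chain is softplus(z + 4): strictly positive, but it can fall below 4.
print(old_bij.forward(z))
# The new bijector keeps every constrained value strictly above 4.
print(new_bij.forward(z))
```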
8 changes: 2 additions & 6 deletions docs/examples/graph_kernels.py
@@ -23,7 +23,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
@@ -134,7 +133,7 @@
# For this reason, we simply perform gradient descent on the GP's marginal
# log-likelihood term as in the
# [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
# We do this using the Adam optimiser provided in `optax`.
# We do this using the BFGS optimiser provided by `scipy` via `jaxopt`.

# %%
likelihood = gpx.Gaussian(num_datapoints=D.n)
@@ -155,13 +154,10 @@
# With a posterior defined, we can now optimise the model's hyperparameters.

# %%
opt_posterior, training_history = gpx.fit(
opt_posterior, training_history = gpx.fit_scipy(
model=posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=1000,
key=key,
)

# %% [markdown]
21 changes: 7 additions & 14 deletions docs/examples/intro_to_kernels.py
@@ -10,7 +10,6 @@

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float
@@ -217,8 +216,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52(
lengthscale=jnp.array(2.0)
) # Initialise our kernel lengthscale to 2.0
lengthscale=jnp.array(0.1)
) # Initialise our kernel lengthscale to 0.1

prior = gpx.Prior(mean_function=mean, kernel=kernel)

@@ -235,16 +234,11 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=no_opt_posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=2000,
safe=True,
key=key,
)


@@ -524,7 +518,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
mean = gpx.mean_functions.Zero()
rbf_kernel = gpx.kernels.RBF(lengthscale=100.0)
periodic_kernel = gpx.kernels.Periodic()
linear_kernel = gpx.kernels.Linear()
linear_kernel = gpx.kernels.Linear(variance=0.001)
sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel])
final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel])

@@ -540,18 +534,17 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
safe=True,
optim=ox.adamw(learning_rate=1e-2),
num_iters=500,
key=key,
)


# %% [markdown]
# Now we can obtain the model's prediction over a period of time which includes the
# training data, as well as 8 years before and after the training data:
20 changes: 4 additions & 16 deletions docs/examples/oceanmodelling.py
@@ -23,7 +23,6 @@
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
import tensorflow_probability as tfp

@@ -239,30 +238,19 @@ def initialise_gp(kernel, mean, dataset):


# %% [markdown]
# With a model now defined, we can proceed to optimise the hyperparameters of our likelihood over $D_0$. This is done by minimising the MLL using `optax`. We also plot its value at each step to visually confirm that we have found the minimum. See the [introduction to Gaussian Processes](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/) notebook for more information on optimising the MLL.
# With a model now defined, we can proceed to optimise the hyperparameters of our likelihood over $D_0$. This is done by minimising the MLL using `BFGS`. See the [introduction to Gaussian Processes](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/) notebook for more information on optimising the MLL.


# %%
def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
def optimise_mll(posterior, dataset, NIters=1000, key=key):
# define the MLL using dataset_train
objective = gpx.objectives.ConjugateMLL(negative=True)
# Optimise to minimise the MLL
optimiser = ox.adam(learning_rate=0.1)
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=objective,
train_data=dataset,
optim=optimiser,
num_iters=NIters,
safe=True,
key=key,
)
# plot MLL value at each iteration
if plot_history:
fig, ax = plt.subplots(1, 1)
ax.plot(history, color=colors[1])
ax.set(xlabel="Training iteration", ylabel="Negative MLL")

return opt_posterior
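With the plotting logic removed from `optimise_mll`, the diagnostic plot can still be produced by the caller, since `fit_scipy` returns a `history` alongside the optimised model. A minimal sketch, not part of the PR, assuming `history` is a per-iteration record of the negative MLL as it was under `gpx.fit`, and reusing the notebook's `posterior` and `dataset_train` objects:

```python
import matplotlib.pyplot as plt
import gpjax as gpx

opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=dataset_train,
)

# Visual check that the objective has been driven to a minimum.
fig, ax = plt.subplots(1, 1)
ax.plot(history)
ax.set(xlabel="Iteration", ylabel="Negative MLL")
plt.show()
```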


@@ -471,7 +459,7 @@ def __call__(
# Redefine Gaussian process with Helmholtz kernel
kernel = HelmholtzKernel()
helmholtz_posterior = initialise_gp(kernel, mean, dataset_train)
# Optimise hyperparameters using optax
# Optimise hyperparameters using BFGS
opt_helmholtz_posterior = optimise_mll(helmholtz_posterior, dataset_train)


18 changes: 2 additions & 16 deletions docs/examples/regression.py
@@ -210,30 +210,16 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# We can now define an optimiser with `scipy`. For this example we'll use the `BFGS`
# optimiser.

# %%
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=500,
safe=True,
key=key,
)

# %% [markdown]
# The calling of `fit` returns two objects: the optimised posterior and a history of
# training losses. We can plot the training loss to see how the optimisation has
# progressed.

# %%
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iteration", ylabel="Negative marginal log likelihood")

# %% [markdown]
# ## Prediction
#
18 changes: 2 additions & 16 deletions docs/examples/regression_mo.py
@@ -228,30 +228,16 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# We can now define an optimiser with `scipy`. For this example we'll use the `BFGS`
# optimiser.

# %%
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=500,
safe=True,
key=key,
)

# %% [markdown]
# The calling of `fit` returns two objects: the optimised posterior and a history of
# training losses. We can plot the training loss to see how the optimisation has
# progressed.

# %%
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iteration", ylabel="Negative marginal log likelihood")

# %% [markdown]
# ## Prediction
#
13 changes: 5 additions & 8 deletions docs/examples/yacht.py
@@ -35,7 +35,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
import pandas as pd
from sklearn.metrics import (
mean_squared_error,
@@ -169,7 +168,9 @@
# %%
n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.RBF(
active_dims=list(range(n_covariates)), lengthscale=np.ones((n_covariates,))
active_dims=list(range(n_covariates)),
variance=np.var(scaled_ytr),
lengthscale=0.1 * np.ones((n_covariates,)),
)
meanf = gpx.mean_functions.Zero()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
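A brief note on the initialisation added here: because the covariates and targets are standardised, starting the kernel variance at the empirical variance of `scaled_ytr` (roughly one) and the per-dimension ARD lengthscales at 0.1 gives the optimiser a short-lengthscale, unit-signal starting point rather than the previous all-ones default. A sketch of the same idea on synthetic data; the `StandardScaler` step and the synthetic `Xtr`/`ytr` are assumptions standing in for the notebook's own preprocessing:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

import gpjax as gpx

rng = np.random.default_rng(0)
Xtr = rng.uniform(size=(50, 6))
ytr = np.sin(Xtr.sum(axis=1, keepdims=True))

# Stand-in for the notebook's scaling of the yacht covariates and targets.
scaled_Xtr = StandardScaler().fit_transform(Xtr)
scaled_ytr = StandardScaler().fit_transform(ytr)

n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.RBF(
    active_dims=list(range(n_covariates)),
    variance=np.var(scaled_ytr),                 # ~1.0 after standardisation
    lengthscale=0.1 * np.ones((n_covariates,)),  # one short lengthscale per input
)
```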
@@ -182,21 +183,17 @@
# ### Model Optimisation
#
# With a model now defined, we can proceed to optimise the hyperparameters of our
# model using Optax.
# model using SciPy.

# %%
training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)

negative_mll = jit(gpx.ConjugateMLL(negative=True))
optimiser = ox.adamw(0.05)

opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=training_data,
optim=ox.adamw(learning_rate=0.05),
num_iters=500,
key=key,
)

# %% [markdown]
Expand Down
6 changes: 5 additions & 1 deletion gpjax/__init__.py
@@ -22,7 +22,10 @@
)
from gpjax.citation import cite
from gpjax.dataset import Dataset
from gpjax.fit import fit
from gpjax.fit import (
fit,
fit_scipy,
)
from gpjax.gps import (
Prior,
construct_posterior,
@@ -87,6 +90,7 @@
"decision_making",
"kernels",
"fit",
"fit_scipy",
"Prior",
"construct_posterior",
"integrators",