From 992f3bc18e0f853c4573807e2efe2a2594c63930 Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Wed, 5 Jul 2023 16:02:34 +0100 Subject: [PATCH 1/4] Add introductory kernel notebook and change style file path in notebooks Main change is to add a notebook introducing the concept of a kernel for those new to Gaussian processes. This focuses on mathematical intuition and introduces useful concepts such as covariance matrices and positive-definiteness. Also noted that the README file for writing documentation hadn't been updated since switching to using MkDocs, so updated it to reflect these changes. Also found that the relative path to the style file for the notebooks caused issues when running the `poetry run mkdocs serve` command. Now instead of using a relative path to the style file, we point to the URL of the style file directly. I have copied the style file into the `_static` directory, and a future PR will point to this URL instead. Finally made a few other minor edits: - Updated docstring for the periodic kernel as it was incomplete. - Made some minor fixes to a few other docstrings I found flaws in. --- docs/_static/gpjax.mplstyle | 54 +++++ docs/examples/README.md | 46 ++-- docs/examples/barycentres.py | 4 +- docs/examples/classification.py | 4 +- docs/examples/collapsed_vi.py | 4 +- docs/examples/deep_kernels.py | 4 +- docs/examples/graph_kernels.py | 4 +- docs/examples/intro_to_gps.py | 4 +- docs/examples/intro_to_kernels.py | 328 +++++++++++++++++++++++++++ docs/examples/kernels.py | 4 +- docs/examples/likelihoods_guide.py | 4 +- docs/examples/poisson.py | 4 +- docs/examples/regression.py | 4 +- docs/examples/spatial.py | 4 +- docs/examples/uncollapsed_vi.py | 4 +- docs/examples/yacht.py | 4 +- gpjax/kernels/stationary/matern12.py | 2 +- gpjax/kernels/stationary/matern32.py | 4 +- gpjax/kernels/stationary/periodic.py | 6 +- gpjax/kernels/stationary/rbf.py | 2 +- mkdocs.yml | 5 +- 21 files changed, 456 insertions(+), 43 deletions(-) create mode 100644 docs/_static/gpjax.mplstyle create mode 100644 docs/examples/intro_to_kernels.py diff --git a/docs/_static/gpjax.mplstyle b/docs/_static/gpjax.mplstyle new file mode 100644 index 000000000..b63425a9f --- /dev/null +++ b/docs/_static/gpjax.mplstyle @@ -0,0 +1,54 @@ +figure.figsize: 5.5, 2.5 +figure.constrained_layout.use: True +figure.autolayout: False +savefig.bbox: tight +figure.dpi: 120 + +# Axes +axes.spines.left: True # display axis spines +axes.spines.bottom: True +axes.spines.top: False +axes.spines.right: False +axes.grid: true +axes.axisbelow: true + +### Fonts +mathtext.fontset: cm +font.family: serif +font.serif: Computer Modern Roman +font.size: 10 +text.usetex: True + +# Axes ticks +ytick.left: True +xtick.bottom: True +xtick.direction: out +ytick.direction: out + +# Colour palettes +axes.prop_cycle: cycler('color', ['2F83B4','B5121B', 'F77F00', '0B6E4F', '7A68A6', 'C5BB36', '8c564b', 'e377c2']) +lines.color: B5121B +scatter.marker: x +image.cmap: inferno + +### Grids +grid.linestyle: - +grid.linewidth: 0.2 +grid.color: cbcbcb + +### Legend +legend.frameon: True +legend.loc: best +legend.fontsize: 8 +legend.fancybox: True +legend.scatterpoints: 1 +legend.numpoints: 1 + +patch.antialiased: True + +# set text objects edidable in Adobe Illustrator +pdf.fonttype: 42 +ps.fonttype: 42 + +# no background +savefig.transparent: True diff --git a/docs/examples/README.md b/docs/examples/README.md index 82e6a510f..5cdfdd9ce 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -5,19 +5,21 @@ https://docs.jaxgaussianprocesses.com/ # How to build the docs -1. Install the requirements using `pip install -r docs/requirements.txt` +1. Ensure you have installed the requirements using `poetry install` in the root directory. 2. Make sure `pandoc` is installed -3. Run the make script `make html` +3. Run the command `poetry run mkdocs serve` in the root directory. -The corresponding HTML files can then be found in `docs/_build/html/`. +The documentation will then be served at an IP address printed, which can then be opened in +a browser of you choice e.g. `Serving on http://127.0.0.1:8000/`. # How to write code documentation -Our documentation it is written in ReStructuredText for Sphinx. This is a -meta-language that is compiled into online documentation. For more details see -[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). -As a result, our docstrings adhere to a specific syntax that has to be kept in -mind. Below we provide some guidelines. +Our documentation is generated using [MkDocs](https://www.mkdocs.org/). This automatically creates online documentation +from docstrings, with full support for Markdown. Longer tutorial-style notebooks are also converted to webpages by MkDocs, +with these notebooks being stored in the `docs/examples` directory. If you write a new notebook and wish to add it to +the documentation website, add it to the `nav` section of the `mkdocs.yml` file found in the root directory. + +Below we provide some guidelines for writing docstrings. ## How much information to put in a docstring @@ -53,16 +55,13 @@ class Prior(AbstractPrior): [mean](https://docs.jaxgaussianprocesses.com/api/mean_functions/) and [kernel](https://docs.jaxgaussianprocesses.com/api/kernels/base/) function. - A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel - function :math:`k(\\cdot, \\cdot)` is given by - - .. math:: - - p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)). + A Gaussian process prior parameterised by a mean function $`m(\cdot)`$ and a kernel + function $`k(\cdot, \cdot)`$ is given by + $`p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))`$. - To invoke a ``Prior`` distribution, only a kernel function is required. By default, - the mean function will be set to zero. In general, this assumption will be reasonable - assuming the data being modelled has been centred. + To invoke a `Prior` distribution, only a kernel function is required. By + default, the mean function will be set to zero. In general, this assumption + will be reasonable assuming the data being modelled has been centred. Example: >>> import gpjax as gpx @@ -84,10 +83,15 @@ class Prior(AbstractPrior): ### Documentation syntax -A helpful cheatsheet for writing restructured text can be found -[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting -`` objects. +We adopt the following convention when documenting objects: * Class attributes should be specified using the `Attributes:` tag. * Method argument should be specified using the `Args:` tags. -* All attributes and arguments should have types. +* Values returned by a method should be specified using the `Returns:` tag. +* All attributes, arguments and returned values should have types. + +!!! attention "Note" + + Inline math in docstrings needs to be rendered within both `$` and `` symbols to be correctly rendered by MkDocs. + For instance, where one would typically write `$k(x,y)$` in standard LaTeX, in docstrings you are required to + write ``$`k(x,y)`$`` in order for the math to be correctly rendered by MkDocs. diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py index dca5dc9e5..b5235f851 100644 --- a/docs/examples/barycentres.py +++ b/docs/examples/barycentres.py @@ -34,7 +34,9 @@ key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/classification.py b/docs/examples/classification.py index 9d1156c2e..59775abe0 100644 --- a/docs/examples/classification.py +++ b/docs/examples/classification.py @@ -52,7 +52,9 @@ tfd = tfp.distributions identity_matrix = jnp.eye key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/collapsed_vi.py b/docs/examples/collapsed_vi.py index f6071d858..8dd442e36 100644 --- a/docs/examples/collapsed_vi.py +++ b/docs/examples/collapsed_vi.py @@ -44,7 +44,9 @@ import gpjax as gpx key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/deep_kernels.py b/docs/examples/deep_kernels.py index 14215a355..d36a38d1e 100644 --- a/docs/examples/deep_kernels.py +++ b/docs/examples/deep_kernels.py @@ -45,7 +45,9 @@ from gpjax.kernels.computations import AbstractKernelComputation key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py index 386999b74..1ef5d15ad 100644 --- a/docs/examples/graph_kernels.py +++ b/docs/examples/graph_kernels.py @@ -28,7 +28,9 @@ import gpjax as gpx key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/intro_to_gps.py b/docs/examples/intro_to_gps.py index 3bbf2adab..13cff8fb6 100644 --- a/docs/examples/intro_to_gps.py +++ b/docs/examples/intro_to_gps.py @@ -109,7 +109,9 @@ import tensorflow_probability.substrates.jax as tfp from docs.examples.utils import confidence_ellipse -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] tfd = tfp.distributions diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py new file mode 100644 index 000000000..b093b805f --- /dev/null +++ b/docs/examples/intro_to_kernels.py @@ -0,0 +1,328 @@ +# %% [markdown] +# # Introduction to Kernels + +# %% [markdown] +# In this guide we provide an introduction to kernels, and the role they play in Gaussian process models. + +# %% +# Enable Float64 for more stable matrix inversions. +from jax.config import config + +config.update("jax_enable_x64", True) + +from jax import jit +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import install_import_hook +import matplotlib as mpl +import matplotlib.pyplot as plt +import optax as ox +from docs.examples.utils import clean_legend + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + +key = jr.PRNGKey(42) +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] +# Using Gaussian Processes (GPs) to model functions can offer several advantages over alternative methods, such as deep neural networks. One key advantage is their rich quantification of uncertainty; not only do they provide *point estimates* for the values taken by a function throughout its domain, but they provide a full predictive posterior *distribution* over the range of values the function may take. This rich quantification of uncertainty is useful in many applications, such as Bayesian optimisation, which relies on being able to make *uncertainty-aware* decisions. +# +# However, another advantage of GPs is the ability for one to place *priors* on the functions being modelled. For instance, one may know that the underlying function being modelled observes certain characteristics, such as being *periodic* or having a certain level of *smoothness*. The *kernel*, or *covariance function*, is the primary means through which one is able to encode such prior knowledge about the function being modelled. This enables one to equip the GP with inductive biases which enable it to learn from data more efficiently, whilst generalising to unseen data more effectively. +# +# In this notebook we'll develop some intuition for what kinds of priors are encoded through the use of different kernels, and how this can be useful when modelling different types of functions. + +# %% [markdown] +# ## Introducing a Common Family of Kernels - The Matérn Family + +# %% [markdown] +# Intuitively, the kernel defines the notion of *similarity* between the value taken at two points, $\mathbf{x}$ and $\mathbf{x}'$, by a function $f$, and will be denoted as $k(\mathbf{x}, \mathbf{x}')$: +# +# $$k(\mathbf{x}, \mathbf{x}') = \text{Cov}[f(\mathbf{x}), f(\mathbf{x}')]$$ +# +# One would expect that, given a previously unobserved test point $\mathbf{x}^*$, training points which are *closest* to this unobserved point will be most similar to it. As such, the kernel is used to define this notion of similarity within the GP framework. It tends to be up to the user to select a kernel which is appropriate for the function being modelled. +# +# One of the most widely used families of kernels is the Matérn family. These kernels take on the following form: +# +# $$k_{\nu}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \frac{2^{1 - \nu}}{\Gamma(\nu)}\left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)^{\nu} K_{\nu} \left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)$$ +# +# where $K_{\nu}$ is a modified Bessel function, $\nu$, $\kappa$ and $\sigma^2$ are hyperparameters specifying the mean-square differentiability, lengthscale and variability respectively, and $|\cdot|$ is used to denote the Euclidean norm. +# +# In the limit of $\nu \to \infty$ this yields the *squared-exponential*, or *radial basis function (RBF)*, kernel, which is infinitely mean-square differentiable: +# +# $$k_{\infty}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \exp\left(-\frac{|\mathbf{x} - \mathbf{x'}|^2}{2\kappa^2}\right)$$ +# +# But what kind of functions does this kernel encode prior knowledge about? Let's take a look at some samples from GP priors defined used Matérn kernels with different values of $\nu$: + +# %% +kernels = [ + gpx.kernels.Matern12(), + gpx.kernels.Matern32(), + gpx.kernels.Matern52(), + gpx.kernels.RBF(), +] +fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 6), tight_layout=True) + +x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) + +meanf = gpx.mean_functions.Zero() + +for k, ax in zip(kernels, axes.ravel()): + prior = gpx.Prior(mean_function=meanf, kernel=k) + rv = prior(x) + y = rv.sample(seed=key, sample_shape=(10,)) + ax.plot(x, y.T, alpha=0.7) + ax.set_title(k.name) + + +# %% [markdown] +# It should be noted that commonly used Matérn kernels use half-integer values of $\nu$, such as $\nu = 1/2$ or $\nu = 5/2$. The fraction is sometimes omitted when naming the kernel, so that $\nu = 1/2$ is referred to as the Matérn12 kernel, and $\nu = 5/2$ is referred to as the Matérn52 kernel. +# +# The plots above clearly show that the choice of $\nu$ has a large impact on the *smoothness* of the functions being modelled by the GP, with functions drawn from GPs defined with the Matérn kernel becoming increasingly smooth as $\nu \to \infty$. More formally, this notion of smoothness is captured through the mean-square differentiability of the function being modelled. Functions sampled from GPs using a Matérn kernel are $k$-times mean-square differentiable, if and only if $\nu > k$. For instance, functions sampled from a GP using a Matérn12 kernel are zero times mean-square differentiable, and functions sampled from a GP using the RBF kernel are infinitely mean-square differentiable. +# +# As an important aside, a general property of the Matérn family of kernels is that they are examples of *stationary* kernels. This means that they only depend on the *displacement* of the two points being compared, $\mathbf{x} - \mathbf{x}'$, and not on their absolute values. This is a useful property to have, as it means that the kernel is invariant to translations in the input space. They also go beyond this, as they only depend on the Euclidean *distance* between the two points being compared, $|\mathbf{x} - \mathbf{x}'|$. Kernels which satisfy this property are known as *isotropic* kernels. This makes the function invariant to all rigid motions in the input space, such as rotations. + +# %% [markdown] +# ## Inferring Kernel Hyperparameters + +# %% [markdown] +# Most kernels have several *hyperparameters*, which we denote $\mathbf{\theta}$, which encode different assumptions about the underlying function being modelled. For the Matérn family descibred above, $\mathbf{\theta} = \{\nu, \kappa, \sigma\}$. A fully Bayesian approach to dealing with hyperparameters would be to place a prior over them, and marginalise over the posterior derived from the data in order to perform predictions. However, this is often computationally very expensive, and so a common approach is to instead *optimise* the hyperparameters by maximising the log marginal likelihood of the data. Given training data $\mathbf{D} = (\mathbf{X}, \mathbf{y})$, assumed to contain some additive Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2)$, the log marginal likelihood of the dataset is defined as: +# +# $$ \begin{aligned} +# \log(p(\mathbf{y} | \mathbf{X}, \boldsymbol{\theta})) &= \log\left(\int p(\mathbf{y} | \mathbf{f}, \mathbf{X}, \boldsymbol{\theta}) p(\mathbf{f} | \mathbf{X}, \boldsymbol{\theta}) d\mathbf{f}\right) \nonumber \\ +# &= - \frac{1}{2} \mathbf{y} ^ \top \left(K(\mathbf{X}, \mathbf{X}) + \sigma^2 \mathbf{I} \right)^{-1} \mathbf{y} - \frac{1}{2} \log |K(\mathbf{X}, \mathbf{X}) + \sigma^2 \mathbf{I}| - \frac{n}{2} \log 2 \pi +# \end{aligned}$$ + +# %% [markdown] +# We'll demonstrate the advantages of being able to infer kernel parameters from the training data by fitting a GP to the widely used [Forrester function](https://www.sfu.ca/~ssurjano/forretal08.html): +# +# $$f(x) = (6x - 2)^2 \sin(12x - 4)$$ + + +# %% +# Forrester function +def forrester(x): + return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4) + + +n = 5 + +training_x = jr.uniform(key=key, minval=0, maxval=1, shape=(n,)).reshape(-1, 1) +training_y = forrester(training_x) +D = gpx.Dataset(X=training_x, y=training_y) + +test_x = jnp.linspace(0, 1, 100).reshape(-1, 1) +test_y = forrester(test_x) + +# %% [markdown] +# First we define our model, using the Matérn32 kernel, and construct our posterior *without* optimising the kernel hyperparameters: + +# %% +mean = gpx.mean_functions.Zero() +kernel = gpx.kernels.Matern32( + lengthscale=jnp.array(2.0) +) # Initialise our kernel lengthscale to 2.0 + +prior = gpx.Prior(mean_function=mean, kernel=kernel) + +likelihood = gpx.Gaussian( + num_datapoints=D.n, obs_noise=jnp.array(1e-6) +) # Our function is noise-free, so we set the observation noise to a very small value +likelihood = likelihood.replace_trainable(obs_noise=False) + +no_opt_posterior = prior * likelihood + +# %% [markdown] +# We can then optimise the hyperparmeters by minimising the negative log marginal likelihood of the data: + +# %% +negative_mll = gpx.objectives.ConjugateMLL(negative=True) +negative_mll(no_opt_posterior, train_data=D) +negative_mll = jit(negative_mll) + +opt_posterior, history = gpx.fit( + model=no_opt_posterior, + objective=negative_mll, + train_data=D, + optim=ox.adam(learning_rate=0.01), + num_iters=2000, + safe=True, + key=key, +) + + +# %% +opt_latent_dist = opt_posterior.predict(test_x, train_data=D) +opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist) + +opt_predictive_mean = opt_predictive_dist.mean() +opt_predictive_std = opt_predictive_dist.stddev() + +fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6)) +ax1.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5) +ax1.fill_between( + test_x.squeeze(), + opt_predictive_mean - 2 * opt_predictive_std, + opt_predictive_mean + 2 * opt_predictive_std, + alpha=0.2, + label="Two sigma", + color=cols[1], +) +ax1.plot( + test_x, + opt_predictive_mean - 2 * opt_predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax1.plot( + test_x, + opt_predictive_mean + 2 * opt_predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax1.plot( + test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2 +) +ax1.plot(test_x, opt_predictive_mean, label="Predictive mean", color=cols[1]) +ax1.set_title("Posterior with Hyperparameter Optimisation") +ax1.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) + +no_opt_latent_dist = no_opt_posterior.predict(test_x, train_data=D) +no_opt_predictive_dist = no_opt_posterior.likelihood(no_opt_latent_dist) + +no_opt_predictive_mean = no_opt_predictive_dist.mean() +no_opt_predictive_std = no_opt_predictive_dist.stddev() + +ax2.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5) +ax2.fill_between( + test_x.squeeze(), + no_opt_predictive_mean - 2 * no_opt_predictive_std, + no_opt_predictive_mean + 2 * no_opt_predictive_std, + alpha=0.2, + label="Two sigma", + color=cols[1], +) +ax2.plot( + test_x, + no_opt_predictive_mean - 2 * no_opt_predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax2.plot( + test_x, + no_opt_predictive_mean + 2 * no_opt_predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax2.plot( + test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2 +) +ax2.plot(test_x, no_opt_predictive_mean, label="Predictive mean", color=cols[1]) +ax2.set_title("Posterior without Hyperparameter Optimisation") +ax2.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) + +# %% [markdown] +# We can see that optimising the hyperparameters by minimising the negative log marginal likelihood of the data results in a more faithful fit of the GP to the data. In particular, we can observe that the GP using optimised hyperparameters is more accurately able to reflect uncertainty in its predictions, as opposed to the GP using the default parameters, which is overconfident in its predictions. +# +# The lengthscale, $\kappa$, and variance, $\sigma^2$, are shown below, both before and after optimisation: + +# %% +no_opt_lengthscale = no_opt_posterior.prior.kernel.lengthscale +no_opt_variance = no_opt_posterior.prior.kernel.variance +opt_lengthscale = opt_posterior.prior.kernel.lengthscale +opt_variance = opt_posterior.prior.kernel.variance + +print(f"Optimised Lengthscale: {opt_lengthscale} and Variance: {opt_variance}") +print( + f"Non-Optimised Lengthscale: {no_opt_lengthscale} and Variance: {no_opt_variance}" +) + +# %% [markdown] +# ## Expressing Other Priors with Different Kernels + +# %% [markdown] +# Whilst the Matérn kernels are often used as a first choice of kernel, and they often perform well due to their smoothing properties often being well-aligned with the properties of the underlying function being modelled, sometimes more prior knowledge is known about the function being modelled. For instance, it may be known that the function being modelled is *periodic*. In this case, a suitable kernel choice would be the *periodic* kernel: +# +# $$k(\mathbf{x}, \mathbf{x}') = \sigma^2 \exp \left( -\frac{1}{2} \sum_{i=1}^{D} \left(\frac{\sin (\pi (\mathbf{x}_i - \mathbf{x}_i')/p)}{\ell}\right)^2 \right)$$ +# +# with $D$ being the dimensionality of the inputs. +# +# Below we show $10$ samples drawn from a GP prior using the periodic kernel: + +# %% +mean = gpx.mean_functions.Zero() +kernel = gpx.kernels.Periodic() +prior = gpx.Prior(mean_function=mean, kernel=kernel) + +x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) +rv = prior(x) +y = rv.sample(seed=key, sample_shape=(10,)) + +fig, ax = plt.subplots() +ax.plot(x, y.T, alpha=0.7) +ax.set_title("Samples from the Periodic Kernel") +plt.show() + +# %% [markdown] +# In other scenarios, it may be known that the underlying function is *linear*, in which case the *linear* kernel would be a suitable choice: +# +# $$k(\mathbf{x}, \mathbf{x}') = \sigma^2 \mathbf{x}^\top \mathbf{x}'$$ +# +# Unlike the kernels shown above, the linear kernel is *not* stationary, and so it is not invariant to translations in the input space. +# +# Below we show $10$ samples drawn from a GP prior using the linear kernel: + +# %% +mean = gpx.mean_functions.Zero() +kernel = gpx.kernels.Linear() +prior = gpx.Prior(mean_function=mean, kernel=kernel) + +x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) +rv = prior(x) +y = rv.sample(seed=key, sample_shape=(10,)) + +fig, ax = plt.subplots() +ax.plot(x, y.T, alpha=0.7) +ax.set_title("Samples from the Linear Kernel") +plt.show() + +# %% [markdown] +# ## What are the Necessary Conditions for a Valid Kernel? + +# %% [markdown] +# In this guide we have introduced several different kernel functions, $k$, which may make you wonder if any function of two input pairs you construct will make a valid kernel function? Alas, not any function can be used as a kernel function in a GP, and there is a necessary condition a function must satisfy in order to be a valid kernel function. +# +# In order to understand the necessary condition, it is useful to introduce the idea of a *Gram matrix*. As introduced in the [GP introduction notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/), given $n$ input points, $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$, the *Gram matrix* is defined as: +# +# $$K(\mathbf{X}, \mathbf{X}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$ +# +# such that $K(\mathbf{X}, \mathbf{X})_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$. +# +# In order for $k$ to be a valid kernel/covariance function, the corresponding covariance martrix must be *positive semi-definite*. A real $n \times n$ matrix $K$ is positive semi-definite if and only if for all vectors $\mathbf{z} \in \mathbb{R}^n$, $\mathbf{z}^\top K \mathbf{z} \geq 0$. Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if all of its eigenvalues are non-negative. + +# %% [markdown] +# ## Defining Kernels on Non-Euclidean Spaces +# +# In this notebook, we have focused solely on kernels whose domain resides in Euclidean space. However, what if one wished to work with data whose domain is non-Euclidean? For instance, one may wish to work with graph-structured data, or data which lies on a manifold, or even strings. Fortunately, kernels exist for a wide variety of domains. Whilst this is beyond the scope of this notebook, feel free to checkout out our [notebook on graph kernels](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/) for an introduction on how to define the Matérn kernel on graph-structured data, and there are a wide variety of resources online for learning about defining kernels in other domains. In terms of open-source libraries, the [Geometric Kernels](https://github.com/GPflow/GeometricKernels) library could be a good place to start if you're interested in looking at how these kernels may be implemented, with the additional benefit that it is compatible with GPJax. + +# %% [markdown] +# ## Further Reading +# +# Congratulations on making it this far! We hope that this guide has given you a good introduction to kernels and how they can be used in GPJax. If you're interested in learning more about kernels, we recommend the following resources, which have also been used as inspiration for this guide: +# +# - [Gaussian Processes for Machine Learning](http://www.gaussianprocess.org/gpml/chapters/RW.pdf) - Chapter 4 provides a comprehensive overview of kernels, diving deep into some of the technical details and also providing some kernels defined on non-Euclidean spaces such as strings. +# - David Duvenaud's [Kernel Cookbook](https://www.cs.toronto.edu/~duvenaud/cookbook/) is a great resource for learning about kernels, and also provides some information about some of the pitfalls people commonly encounter when using the Matérn family of kernels. His PhD thesis, [Automatic Model Construction with Gaussian Processes](https://www.cs.toronto.edu/~duvenaud/thesis.pdf), also provides some in-depth recipes for how one may incorporate their prior knowledge when constructing kernels. +# - Finally, please check out our [more advanced kernel guide](https://docs.jaxgaussianprocesses.com/examples/kernels/), which details some more kernels available in GPJax as well as how one may combine kernels together to form more complex kernels. +# +# ## System Configuration + +# %% +# %reload_ext watermark +# %watermark -n -u -v -iv -w -a 'Thomas Christie' diff --git a/docs/examples/kernels.py b/docs/examples/kernels.py index 20e7edf1b..e218a5f77 100644 --- a/docs/examples/kernels.py +++ b/docs/examples/kernels.py @@ -46,7 +46,9 @@ key = jr.PRNGKey(123) tfb = tfp.bijectors -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py index ebecc0634..f2b8d961c 100644 --- a/docs/examples/likelihoods_guide.py +++ b/docs/examples/likelihoods_guide.py @@ -62,7 +62,9 @@ import tensorflow_probability.substrates.jax as tfp tfd = tfp.distributions -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] key = jr.PRNGKey(123) diff --git a/docs/examples/poisson.py b/docs/examples/poisson.py index 7fbe695d6..da740665f 100644 --- a/docs/examples/poisson.py +++ b/docs/examples/poisson.py @@ -41,7 +41,9 @@ config.update("jax_enable_x64", True) tfd = tfp.distributions key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/regression.py b/docs/examples/regression.py index 374514774..bccbb8068 100644 --- a/docs/examples/regression.py +++ b/docs/examples/regression.py @@ -38,7 +38,9 @@ import gpjax as gpx key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/spatial.py b/docs/examples/spatial.py index 20a4a03da..140088baf 100644 --- a/docs/examples/spatial.py +++ b/docs/examples/spatial.py @@ -69,7 +69,9 @@ key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # Observed temperature data diff --git a/docs/examples/uncollapsed_vi.py b/docs/examples/uncollapsed_vi.py index 3652f63f1..76eae5ec1 100644 --- a/docs/examples/uncollapsed_vi.py +++ b/docs/examples/uncollapsed_vi.py @@ -51,7 +51,9 @@ key = jr.PRNGKey(123) tfb = tfp.bijectors -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py index 48b5c442c..f3a3d2ab8 100644 --- a/docs/examples/yacht.py +++ b/docs/examples/yacht.py @@ -33,7 +33,9 @@ # Enable Float64 for more stable matrix inversions. key = jr.PRNGKey(123) -plt.style.use("./gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 1a337e018..002de74d7 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -49,7 +49,7 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: Evaluate the kernel on a pair of inputs $`(x, y)`$ with lengthscale parameter $`\ell`$ and variance $`\sigma^2`$. ```math - (x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell^2}\Bigg) + k(x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell^2}\Bigg) ``` Args: diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index da1b34527..ac3b79699 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -51,10 +51,10 @@ def __call__( r"""Compute the Matérn 3/2 kernel between a pair of arrays. Evaluate the kernel on a pair of inputs $`(x, y)`$ with - lengthscale parameter $\ell$ and variance $`\sigma^2`$. + lengthscale parameter $`\ell`$ and variance $`\sigma^2`$. ```math - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg) ``` Args: diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 695054ed8..1753b82c7 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -45,10 +45,10 @@ class Periodic(AbstractKernel): def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the Periodic kernel between a pair of arrays. - TODO: update docstring - Evaluate the kernel on a pair of inputs $`(x, y)`$ with length-scale parameter $\ell$ and variance $\sigma$. + Evaluate the kernel on a pair of inputs $`(x, y)`$ with length-scale parameter $`\ell`$, variance $`\sigma^2`$ + and period $`p`$. ```math - k(x, y) = \sigma^2 \exp \Bigg( -0.5 \sum_{i=1}^{d} \Bigg) + k(x, y) = \sigma^2 \exp \left( -\frac{1}{2} \sum_{i=1}^{D} \left(\frac{\sin (\pi (x_i - y_i)/p)}{\ell}\right)^2 \right) ``` Args: diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 262fd7a7d..6f2cd2b56 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -46,7 +46,7 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: Evaluate the kernel on a pair of inputs $`(x, y)`$ with lengthscale parameter $`\ell`$ and variance $`\sigma^2`$: ```math - k(x,y)=\sigma^2\exp\Bigg(\frac{\lVert x - y \rVert^2_2}{2 \ell^2} \Bigg) + k(x,y)=\sigma^2\exp\Bigg(- \frac{\lVert x - y \rVert^2_2}{2 \ell^2} \Bigg) ``` Args: diff --git a/mkdocs.yml b/mkdocs.yml index da2b80b58..dc091a2ea 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -17,6 +17,7 @@ nav: - 📎 JAX 101 [External]: https://jax.readthedocs.io/en/latest/jax-101/index.html - 💡 Background: - Intro to GPs: examples/intro_to_gps.py + - Intro to Kernels: examples/intro_to_kernels.py - 🎓 Tutorials: - Regression: examples/regression.py - Classification: examples/classification.py @@ -130,11 +131,11 @@ extra: extra_css: - stylesheets/extra.css - stylesheets/permalinks.css - - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.9.0/katex.min.css + - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.8/katex.min.css extra_javascript: # - javascripts/mathjax.js - https://polyfill.io/v3/polyfill.min.js?features=es6 # - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js - - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.9.0/katex.min.js + - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.8/katex.min.js - javascripts/katex.js From 62aea3d13e99c25ac242ab48f2f6dd23b953bab6 Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Thu, 6 Jul 2023 09:35:17 +0100 Subject: [PATCH 2/4] Add edits to introductory kernel notebook --- docs/examples/README.md | 55 +++++++++++++-------------- docs/examples/intro_to_kernels.py | 63 +++++++++++++++++++++++++------ gpjax/gps.py | 4 +- 3 files changed, 79 insertions(+), 43 deletions(-) diff --git a/docs/examples/README.md b/docs/examples/README.md index 5cdfdd9ce..a5188c35e 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -1,46 +1,47 @@ # Where to find the docs -The GPJax documentation can be found here: -https://docs.jaxgaussianprocesses.com/ +The GPJax documentation can be found here: https://docs.jaxgaussianprocesses.com/ # How to build the docs -1. Ensure you have installed the requirements using `poetry install` in the root directory. +1. Ensure you have installed the requirements using `poetry install` in the root + directory. 2. Make sure `pandoc` is installed 3. Run the command `poetry run mkdocs serve` in the root directory. -The documentation will then be served at an IP address printed, which can then be opened in -a browser of you choice e.g. `Serving on http://127.0.0.1:8000/`. +The documentation will then be served at an IP address printed, which can then be opened +in a browser of you choice e.g. `Serving on http://127.0.0.1:8000/`. # How to write code documentation -Our documentation is generated using [MkDocs](https://www.mkdocs.org/). This automatically creates online documentation -from docstrings, with full support for Markdown. Longer tutorial-style notebooks are also converted to webpages by MkDocs, -with these notebooks being stored in the `docs/examples` directory. If you write a new notebook and wish to add it to -the documentation website, add it to the `nav` section of the `mkdocs.yml` file found in the root directory. +Our documentation is generated using [MkDocs](https://www.mkdocs.org/). This +automatically creates online documentation from docstrings, with full support for +Markdown. Longer tutorial-style notebooks are also converted to webpages by MkDocs, with +these notebooks being stored in the `docs/examples` directory. If you write a new +notebook and wish to add it to the documentation website, add it to the `nav` section of +the `mkdocs.yml` file found in the root directory. Below we provide some guidelines for writing docstrings. ## How much information to put in a docstring -A docstring should be informative. If in doubt, then it is best to add more -information to a docstring than less. Many users will skim documentation, so -please ensure the opening sentence or two of a docstring contains the core -information. Adding examples and mathematical descriptions to documentation is -highly desirable. +A docstring should be informative. If in doubt, then it is best to add more information +to a docstring than less. Many users will skim documentation, so please ensure the +opening sentence or two of a docstring contains the core information. Adding examples +and mathematical descriptions to documentation is highly desirable. -We are making an active effort within GPJax to improve our documentation. If you -spot any areas where there is missing information within the existing -documentation, then please either raise an issue or -[create a pull request](https://docs.jaxgaussianprocesses.com/contributing/). +We are making an active effort within GPJax to improve our documentation. If you spot +any areas where there is missing information within the existing documentation, then +please either raise an issue or [create a pull +request](https://docs.jaxgaussianprocesses.com/contributing/). ## An example docstring -An example docstring that adheres the principles of GPJax is given below. -The docstring contains a simple, snappy introduction with links to auxiliary -components. More detail is then provided in the form of a mathematical -description and a code example. The docstring is concluded with a description -of the objects attributes with corresponding types. +An example docstring that adheres the principles of GPJax is given below. The docstring +contains a simple, snappy introduction with links to auxiliary components. More detail +is then provided in the form of a mathematical description and a code example. The +docstring is concluded with a description of the objects attributes with corresponding +types. ```python from gpjax.gps import AbstractPrior @@ -59,9 +60,7 @@ class Prior(AbstractPrior): function $`k(\cdot, \cdot)`$ is given by $`p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))`$. - To invoke a `Prior` distribution, only a kernel function is required. By - default, the mean function will be set to zero. In general, this assumption - will be reasonable assuming the data being modelled has been centred. + To invoke a `Prior` distribution, a kernel and mean function must be specified. Example: >>> import gpjax as gpx @@ -92,6 +91,4 @@ We adopt the following convention when documenting objects: !!! attention "Note" - Inline math in docstrings needs to be rendered within both `$` and `` symbols to be correctly rendered by MkDocs. - For instance, where one would typically write `$k(x,y)$` in standard LaTeX, in docstrings you are required to - write ``$`k(x,y)`$`` in order for the math to be correctly rendered by MkDocs. + Inline math in docstrings needs to be rendered within both `$` and `` symbols to be correctly rendered by MkDocs. For instance, where one would typically write `$k(x,y)$` in standard LaTeX, in docstrings you are required to write ``$`k(x,y)`$`` in order for the math to be correctly rendered by MkDocs. diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index b093b805f..f5384c4a2 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -13,7 +13,7 @@ from jax import jit import jax.numpy as jnp import jax.random as jr -from jaxtyping import install_import_hook +from jaxtyping import install_import_hook, Float import matplotlib as mpl import matplotlib.pyplot as plt import optax as ox @@ -21,6 +21,7 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from gpjax.typing import Array key = jr.PRNGKey(42) plt.style.use( @@ -39,17 +40,36 @@ # ## Introducing a Common Family of Kernels - The Matérn Family # %% [markdown] -# Intuitively, the kernel defines the notion of *similarity* between the value taken at two points, $\mathbf{x}$ and $\mathbf{x}'$, by a function $f$, and will be denoted as $k(\mathbf{x}, \mathbf{x}')$: +# Intuitively, for a function $f$, the kernel defines the notion of *similarity* between +# the value of the function at two points, $f(\mathbf{x})$ and $f(\mathbf{x}')$, and +# will be denoted as $k(\mathbf{x}, \mathbf{x}')$: # # $$k(\mathbf{x}, \mathbf{x}') = \text{Cov}[f(\mathbf{x}), f(\mathbf{x}')]$$ # -# One would expect that, given a previously unobserved test point $\mathbf{x}^*$, training points which are *closest* to this unobserved point will be most similar to it. As such, the kernel is used to define this notion of similarity within the GP framework. It tends to be up to the user to select a kernel which is appropriate for the function being modelled. +# One would expect that, given a previously unobserved test point $\mathbf{x}^*$, the +# training points which are *closest* to this unobserved point will be most similar to +# it. As such, the kernel is used to define this notion of similarity within the GP +# framework. It is up to the user to select a kernel which is appropriate for the +# function being modelled. # # One of the most widely used families of kernels is the Matérn family. These kernels take on the following form: # # $$k_{\nu}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \frac{2^{1 - \nu}}{\Gamma(\nu)}\left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)^{\nu} K_{\nu} \left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)$$ # -# where $K_{\nu}$ is a modified Bessel function, $\nu$, $\kappa$ and $\sigma^2$ are hyperparameters specifying the mean-square differentiability, lengthscale and variability respectively, and $|\cdot|$ is used to denote the Euclidean norm. +# where $K_{\nu}$ is a modified Bessel function, $\nu$, $\kappa$ and $\sigma^2$ are +# hyperparameters specifying the mean-square differentiability, lengthscale and +# variability respectively, and $|\cdot|$ is used to denote the Euclidean norm. +# +# Some commonly used Matérn kernels use half-integer values of $\nu$, such as $\nu = 1/2$ +# or $\nu = 3/2$. The fraction is sometimes omitted when naming the kernel, so that $\nu = +# 1/2$ is referred to as the Matérn12 kernel, and $\nu = 3/2$ is referred to as the +# Matérn32 kernel. When $\nu$ takes in a half-integer value, $\nu = k + 1/2$, the kernel +# can be expressed as the product of a polynomial of order $k$ and an exponential: +# +# $$k_{k + 1/2}(\mathbf{x}, \mathbf{x'}) = \sigma^2 +# \exp\left(-\frac{\sqrt{2\nu}|\mathbf{x} - \mathbf{x'}|}{\kappa}\right) +# \frac{\Gamma(k+1)}{\Gamma(2k+1)} \times \sum_{i= 0}^k \frac{(k+i)!}{i!(k-i)!} +# \left(\frac{(\sqrt{8\nu}|\mathbf{x} - \mathbf{x'}|)}{\kappa}\right)^{k-i}$$ # # In the limit of $\nu \to \infty$ this yields the *squared-exponential*, or *radial basis function (RBF)*, kernel, which is infinitely mean-square differentiable: # @@ -79,8 +99,6 @@ # %% [markdown] -# It should be noted that commonly used Matérn kernels use half-integer values of $\nu$, such as $\nu = 1/2$ or $\nu = 5/2$. The fraction is sometimes omitted when naming the kernel, so that $\nu = 1/2$ is referred to as the Matérn12 kernel, and $\nu = 5/2$ is referred to as the Matérn52 kernel. -# # The plots above clearly show that the choice of $\nu$ has a large impact on the *smoothness* of the functions being modelled by the GP, with functions drawn from GPs defined with the Matérn kernel becoming increasingly smooth as $\nu \to \infty$. More formally, this notion of smoothness is captured through the mean-square differentiability of the function being modelled. Functions sampled from GPs using a Matérn kernel are $k$-times mean-square differentiable, if and only if $\nu > k$. For instance, functions sampled from a GP using a Matérn12 kernel are zero times mean-square differentiable, and functions sampled from a GP using the RBF kernel are infinitely mean-square differentiable. # # As an important aside, a general property of the Matérn family of kernels is that they are examples of *stationary* kernels. This means that they only depend on the *displacement* of the two points being compared, $\mathbf{x} - \mathbf{x}'$, and not on their absolute values. This is a useful property to have, as it means that the kernel is invariant to translations in the input space. They also go beyond this, as they only depend on the Euclidean *distance* between the two points being compared, $|\mathbf{x} - \mathbf{x}'|$. Kernels which satisfy this property are known as *isotropic* kernels. This makes the function invariant to all rigid motions in the input space, such as rotations. @@ -89,7 +107,7 @@ # ## Inferring Kernel Hyperparameters # %% [markdown] -# Most kernels have several *hyperparameters*, which we denote $\mathbf{\theta}$, which encode different assumptions about the underlying function being modelled. For the Matérn family descibred above, $\mathbf{\theta} = \{\nu, \kappa, \sigma\}$. A fully Bayesian approach to dealing with hyperparameters would be to place a prior over them, and marginalise over the posterior derived from the data in order to perform predictions. However, this is often computationally very expensive, and so a common approach is to instead *optimise* the hyperparameters by maximising the log marginal likelihood of the data. Given training data $\mathbf{D} = (\mathbf{X}, \mathbf{y})$, assumed to contain some additive Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2)$, the log marginal likelihood of the dataset is defined as: +# Most kernels have several *hyperparameters*, which we denote $\mathbf{\theta}$, which encode different assumptions about the underlying function being modelled. For the Matérn family described above, $\mathbf{\theta} = \{\nu, \kappa, \sigma\}$. A fully Bayesian approach to dealing with hyperparameters would be to place a prior over them, and marginalise over the posterior derived from the data in order to perform predictions. However, this is often computationally very expensive, and so a common approach is to instead *optimise* the hyperparameters by maximising the log marginal likelihood of the data. Given training data $\mathbf{D} = (\mathbf{X}, \mathbf{y})$, assumed to contain some additive Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2)$, the log marginal likelihood of the dataset is defined as: # # $$ \begin{aligned} # \log(p(\mathbf{y} | \mathbf{X}, \boldsymbol{\theta})) &= \log\left(\int p(\mathbf{y} | \mathbf{f}, \mathbf{X}, \boldsymbol{\theta}) p(\mathbf{f} | \mathbf{X}, \boldsymbol{\theta}) d\mathbf{f}\right) \nonumber \\ @@ -97,6 +115,12 @@ # \end{aligned}$$ # %% [markdown] +# This expression can then be maximised with respect to the hyperparameters using a +# gradient-based approach such as Adam or L-BFGS. Note that we may choose to fix some +# hyperparameters, and in GPJax the parameter $\nu$ is set by the user, and not +# inferred though optimisation. For more details on using the log marginal likelihood to +# optimise kernel hyperparameters, see our [GP introduction notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/#gaussian-process-regression). +# # We'll demonstrate the advantages of being able to infer kernel parameters from the training data by fitting a GP to the widely used [Forrester function](https://www.sfu.ca/~ssurjano/forretal08.html): # # $$f(x) = (6x - 2)^2 \sin(12x - 4)$$ @@ -104,7 +128,7 @@ # %% # Forrester function -def forrester(x): +def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4) @@ -136,7 +160,7 @@ def forrester(x): no_opt_posterior = prior * likelihood # %% [markdown] -# We can then optimise the hyperparmeters by minimising the negative log marginal likelihood of the data: +# We can then optimise the hyperparameters by minimising the negative log marginal likelihood of the data: # %% negative_mll = gpx.objectives.ConjugateMLL(negative=True) @@ -154,6 +178,11 @@ def forrester(x): ) +# %% [markdown] +# Having optimised the hyperparameters, we can now make predictions using the posterior +# with the optimised hyperparameters, and compare them to the predictions made using the +# posterior with the default hyperparameters: + # %% opt_latent_dist = opt_posterior.predict(test_x, train_data=D) opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist) @@ -299,13 +328,25 @@ def forrester(x): # %% [markdown] # In this guide we have introduced several different kernel functions, $k$, which may make you wonder if any function of two input pairs you construct will make a valid kernel function? Alas, not any function can be used as a kernel function in a GP, and there is a necessary condition a function must satisfy in order to be a valid kernel function. # -# In order to understand the necessary condition, it is useful to introduce the idea of a *Gram matrix*. As introduced in the [GP introduction notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/), given $n$ input points, $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$, the *Gram matrix* is defined as: +# In order to understand the necessary condition, it is useful to introduce the idea of a +# *Gram matrix*. We'll use the same notation as the [GP introduction +# notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/), and denote +# $n$ input points as $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$. Given these +# input points and a kernel function $k$ the *Gram matrix* stores the pairwise kernel +# evaluations between all input points. Mathematically, this leads to the Gram matrix being defined as: # # $$K(\mathbf{X}, \mathbf{X}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$ # # such that $K(\mathbf{X}, \mathbf{X})_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$. # -# In order for $k$ to be a valid kernel/covariance function, the corresponding covariance martrix must be *positive semi-definite*. A real $n \times n$ matrix $K$ is positive semi-definite if and only if for all vectors $\mathbf{z} \in \mathbb{R}^n$, $\mathbf{z}^\top K \mathbf{z} \geq 0$. Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if all of its eigenvalues are non-negative. +# In order for $k$ to be a valid kernel/covariance function, the corresponding Gram matrix +# must be *positive semi-definite*. In this case the Gram matrix is referred to as a +# *covariance matrix*. A real $n \times n$ matrix $K$ is positive semi-definite if and +# only if for all vectors $\mathbf{z} \in \mathbb{R}^n$: +# +# $$\mathbf{z}^\top K \mathbf{z} \geq 0$$ +# +# Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if all of its eigenvalues are non-negative. # %% [markdown] # ## Defining Kernels on Non-Euclidean Spaces diff --git a/gpjax/gps.py b/gpjax/gps.py index 1c06927dc..e9739a3f0 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -123,9 +123,7 @@ class Prior(AbstractPrior): function $`k(\cdot, \cdot)`$ is given by $`p(f(\cdot)) = \mathcal{GP}(m(\cdot), k(\cdot, \cdot))`$. - To invoke a `Prior` distribution, only a kernel function is required. By - default, the mean function will be set to zero. In general, this assumption - will be reasonable assuming the data being modelled has been centred. + To invoke a `Prior` distribution, a kernel and mean function must be specified. Example: ```python From 2304a2e44c53c3bf5e66619b35d8f3bc934fd86f Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Thu, 6 Jul 2023 10:14:09 +0100 Subject: [PATCH 3/4] Add reference for Matern kernel --- docs/examples/intro_to_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index f5384c4a2..230d25e4b 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -52,7 +52,7 @@ # framework. It is up to the user to select a kernel which is appropriate for the # function being modelled. # -# One of the most widely used families of kernels is the Matérn family. These kernels take on the following form: +# One of the most widely used families of kernels is the Matérn family ([Matérn, 1960](https://core.ac.uk/download/pdf/11698705.pdf)). These kernels take on the following form: # # $$k_{\nu}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \frac{2^{1 - \nu}}{\Gamma(\nu)}\left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)^{\nu} K_{\nu} \left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)$$ # From 2aae86617f2606f716073e25d0a2d3b390874cec Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Fri, 7 Jul 2023 18:13:20 +0100 Subject: [PATCH 4/4] Add edits to kernel introduction notebook Added an extra section to the kernel introduction notebook detailing how one can create new kernels by adding/multiplying two existing kernels. Also added an example using the Mauna Loa CO2 dataset. Also renamed the original 'kernels.py' notebook to 'constructing_new_kernels.py' and edited references to this notebook. --- ...kernels.py => constructing_new_kernels.py} | 0 docs/examples/intro_to_kernels.py | 358 ++++++++++++++++-- docs/sharp_bits.md | 2 +- mkdocs.yml | 2 +- 4 files changed, 337 insertions(+), 25 deletions(-) rename docs/examples/{kernels.py => constructing_new_kernels.py} (100%) diff --git a/docs/examples/kernels.py b/docs/examples/constructing_new_kernels.py similarity index 100% rename from docs/examples/kernels.py rename to docs/examples/constructing_new_kernels.py diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index 230d25e4b..d447a4d53 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -17,11 +17,13 @@ import matplotlib as mpl import matplotlib.pyplot as plt import optax as ox +import pandas as pd from docs.examples.utils import clean_legend with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx from gpjax.typing import Array +from sklearn.preprocessing import StandardScaler key = jr.PRNGKey(42) plt.style.use( @@ -37,28 +39,96 @@ # In this notebook we'll develop some intuition for what kinds of priors are encoded through the use of different kernels, and how this can be useful when modelling different types of functions. # %% [markdown] -# ## Introducing a Common Family of Kernels - The Matérn Family - -# %% [markdown] +# ## What is a Kernel? +# # Intuitively, for a function $f$, the kernel defines the notion of *similarity* between # the value of the function at two points, $f(\mathbf{x})$ and $f(\mathbf{x}')$, and # will be denoted as $k(\mathbf{x}, \mathbf{x}')$: # -# $$k(\mathbf{x}, \mathbf{x}') = \text{Cov}[f(\mathbf{x}), f(\mathbf{x}')]$$ +# $$\begin{aligned} k(\mathbf{x}, \mathbf{x}') &= \text{Cov}[f(\mathbf{x}), +# f(\mathbf{x}')] \\ &= \mathbb{E}[(f(\mathbf{x}) - \mathbb{E}[f(\mathbf{x})])(f(\mathbf{x}') - \mathbb{E}[f(\mathbf{x}')])] \end{aligned}$$ # # One would expect that, given a previously unobserved test point $\mathbf{x}^*$, the # training points which are *closest* to this unobserved point will be most similar to # it. As such, the kernel is used to define this notion of similarity within the GP -# framework. It is up to the user to select a kernel which is appropriate for the -# function being modelled. +# framework. It is up to the user to select a kernel function which is appropriate for +# the function being modelled. In this notebook we are going to give some examples of +# commonly used kernels, and try to develop an understanding of when one may wish to use +# one kernel over another. However, before we do this, it is worth discussing the +# necessary conditions for a function to be a valid kernel/covariance function. This +# requires a little bit of maths, so for those of you who just wish to obtain an +# intuitive understanding, feel free to skip to the section introducing the Matérn +# family of kernels. +# +# ### What are the necessary conditions for a function to be a valid kernel? +# +# Whilst intuitively the kernel function is used to define the notion of similarity within +# the GP framework, it is important to note that there are two *necessary conditions* +# that a kernel function must satisfy in order to be a valid covariance function. For +# clarity, we will refer to *any* function mapping two inputs to a scalar output as a +# *kernel function*, and we will refer to a *valid* kernel function satisfying the two +# necessary conditions as a *covariance function*. However, it is worth noting that the +# GP community often uses the terms *kernel function* and *covariance function* +# interchangeably. +# +# The first necessary condition is that the covariance function must be *symmetric*, i.e. +# $k(\mathbf{x}, \mathbf{x}') = k(\mathbf{x}', \mathbf{x})$. This is because the +# covariance between two random variables $X$ and $X'$ is symmetric; if one looks at the +# definition of covariance given above, it is clear that it is invariant to swapping the +# order of the inputs $\mathbf{x}$ and $\mathbf{x}'$. +# +# The second necessary condition is that the covariance function must be *positive +# semi-definite* (PSD). In order to understand this condition, it is useful to first +# introduce the concept of a *Gram matrix*. We'll use the same notation as the [GP introduction +# notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/), and denote +# $n$ input points as $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$. Given these +# input points and a kernel function $k$ the *Gram matrix* stores the pairwise kernel +# evaluations between all input points. Mathematically, this leads to the Gram matrix being defined as: +# +# $$K(\mathbf{X}, \mathbf{X}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$ +# +# such that $K(\mathbf{X}, \mathbf{X})_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$. +# +# In order for $k$ to be a valid covariance function, the corresponding Gram matrix +# must be *positive semi-definite*. In this case the Gram matrix is referred to as a +# *covariance matrix*. A real $n \times n$ matrix $K$ is positive semi-definite if and +# only if for all vectors $\mathbf{z} \in \mathbb{R}^n$: +# +# $$\mathbf{z}^\top K \mathbf{z} \geq 0$$ +# +# Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if +# all of its eigenvalues are non-negative. # +# Therefore, the two necessary conditions for a function to be a valid covariance function +# are that it must be *symmetric* and *positive semi-definite*. In this section we have +# referred to *any* function from two inputs to a scalar output as a *kernel function*, +# with its corresponding matrix of pairwise evaluations referred to as the *Gram matrix*, +# and a function satisfying the two necessary conditions as a *covariance function*, with +# its corresponding matrix of pairwise evaluations referred to as the *covariance matrix*. +# This enabled us to easily define the necessary conditions for a function to be a valid +# covariance function. However, as noted previously, the GP community often uses these +# terms interchangeably, and so we will for the remainder of this notebook. +# + +# %% [markdown] +# ## Introducing a Common Family of Kernels - The Matérn Family + +# %% [markdown] # One of the most widely used families of kernels is the Matérn family ([Matérn, 1960](https://core.ac.uk/download/pdf/11698705.pdf)). These kernels take on the following form: # # $$k_{\nu}(\mathbf{x}, \mathbf{x'}) = \sigma^2 \frac{2^{1 - \nu}}{\Gamma(\nu)}\left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)^{\nu} K_{\nu} \left(\sqrt{2\nu} \frac{|\mathbf{x} - \mathbf{x'}|}{\kappa}\right)$$ # # where $K_{\nu}$ is a modified Bessel function, $\nu$, $\kappa$ and $\sigma^2$ are # hyperparameters specifying the mean-square differentiability, lengthscale and -# variability respectively, and $|\cdot|$ is used to denote the Euclidean norm. +# variability respectively, and $|\cdot|$ is used to denote the Euclidean norm. Note that +# for those of you less interested in the mathematical underpinnings of kernels, it isn't +# necessary to understand the exact functional form of the Matérn kernels to +# gain an understanding of how they behave. The key takeaway is that they are +# parameterised by several hyperparameters, and that these hyperparameters dictate the +# behaviour of functions sampled from the corresponding GP. The plots below will provide +# some more intuition for how these hyperparameters affect the behaviour of functions +# sampled from the corresponding GP. +# # # Some commonly used Matérn kernels use half-integer values of $\nu$, such as $\nu = 1/2$ # or $\nu = 3/2$. The fraction is sometimes omitted when naming the kernel, so that $\nu = @@ -323,30 +393,272 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: plt.show() # %% [markdown] -# ## What are the Necessary Conditions for a Valid Kernel? +# ## Composing Kernels + +# %% [markdown] +# It is also mathematically valid to compose kernels through operations such as addition +# and multiplication in order to produce more expressive kernels. For the mathematically +# interested amongst you, this is valid as the resulting kernel functions still satisfy +# the necessary conditions introduced at the [start of this +# notebook](#what-are-the-necessary-conditions-for-a-function-to-be-a-valid-kernel). +# Adding or multiplying kernel functions is equivalent to performing elementwise addition +# or multiplication of the corresponding covariance matrices, and fortunately symmetric, +# positive semi-definite kernels are closed under these operations. This means that +# kernels produced by adding or multiplying other kernels will also be symmetric and +# positive semi-definite, and so will also be valid kernels. GPJax provides the +# functionality required to easily compose kernels via addition and multiplication, which +# we'll demonstrate below. +# +# First, we'll take a look at some samples drawn from a GP prior using a kernel which is +# composed of the sum of a linear kernel and a periodic kernel: + +# %% +kernel_one = gpx.kernels.Linear() +kernel_two = gpx.kernels.Periodic() +sum_kernel = gpx.kernels.SumKernel(kernels=[kernel_one, kernel_two]) +mean = gpx.mean_functions.Zero() +prior = gpx.Prior(mean_function=mean, kernel=sum_kernel) + +x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) +rv = prior(x) +y = rv.sample(seed=key, sample_shape=(10,)) +fig, ax = plt.subplots() +ax.plot(x, y.T, alpha=0.7) +ax.set_title("Samples from a GP Prior with Kernel = Linear + Periodic") +plt.show() + # %% [markdown] -# In this guide we have introduced several different kernel functions, $k$, which may make you wonder if any function of two input pairs you construct will make a valid kernel function? Alas, not any function can be used as a kernel function in a GP, and there is a necessary condition a function must satisfy in order to be a valid kernel function. +# We can see that the samples drawn behave as one would naturally expect through adding +# the two kernels together. In particular, the samples are still periodic, as with the +# periodic kernel, but their mean also linearly increases/decreases as they move away from +# the origin, as seen with the linear kernel. # -# In order to understand the necessary condition, it is useful to introduce the idea of a -# *Gram matrix*. We'll use the same notation as the [GP introduction -# notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/), and denote -# $n$ input points as $\mathbf{X} = \{\mathbf{x}_1, \ldots, \mathbf{x}_n\}$. Given these -# input points and a kernel function $k$ the *Gram matrix* stores the pairwise kernel -# evaluations between all input points. Mathematically, this leads to the Gram matrix being defined as: +# Below we take a look at some samples drawn from a GP prior using a kernel which is +# composed of the same two kernels, but this time multiplied together: + +# %% +kernel_one = gpx.kernels.Linear() +kernel_two = gpx.kernels.Periodic() +sum_kernel = gpx.kernels.ProductKernel(kernels=[kernel_one, kernel_two]) +mean = gpx.mean_functions.Zero() +prior = gpx.Prior(mean_function=mean, kernel=sum_kernel) + +x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) +rv = prior(x) +y = rv.sample(seed=key, sample_shape=(10,)) +fig, ax = plt.subplots() +ax.plot(x, y.T, alpha=0.7) +ax.set_title("Samples from a GP with Kernel = Linear x Periodic") +plt.show() + + +# %% [markdown] +# Once again, the samples drawn behave as one would naturally expect through multiplying +# the two kernels together. In particular, the samples are still periodic but their mean +# linearly increases/decreases as they move away from the origin, and the amplitude of +# the oscillations also linearly increases with increasing distance from the origin. + +# %% [markdown] +# ## Putting it All Together on a Real-World Dataset + +# %% [markdown] +# ### Mauna Loa CO2 Dataset + +# %% [markdown] +# We'll put together some of the ideas we've discussed in this notebook by fitting a GP +# to the [Mauna Loa CO2 dataset](https://www.esrl.noaa.gov/gmd/ccgg/trends/data.html). +# This dataset measures atmospheric CO2 concentration at the Mauna Loa Observatory in +# Hawaii, and is widely used in the GP literature. It contains monthly CO2 readings +# starting in March 1958. Interestingly, there was an eruption at the Mauna Loa volcano in +# November 2022, so readings from December 2022 have changed to a site roughly 21 miles +# North of the Mauna Loa Observatory. We'll use the data from March 1958 to November 2022, +# and see how our GP extrapolates to 8 years before and after the data in the training +# set. +# +# First we'll load the data and plot it: + +# %% +co2_data = pd.read_csv( + "https://gml.noaa.gov/webdata/ccgg/trends/co2/co2_mm_mlo.csv", comment="#" +) +co2_data = co2_data.loc[co2_data["decimal date"] < 2022 + 11 / 12] +train_x = co2_data["decimal date"].values[:, None] +train_y = co2_data["average"].values[:, None] + +fig, ax = plt.subplots() +ax.plot(train_x, train_y) +ax.set_title("CO2 Concentration in the Atmosphere") +ax.set_xlabel("Year") +ax.set_ylabel("CO2 Concentration (ppm)") +plt.show() + +# %% [markdown] +# Looking at the data, we can see that there is clearly a periodic trend, with a period of +# roughly 1 year. We can also see that the data is increasing over time, which is +# also expected. This looks roughly linear, although it may have a non-linear component. +# This information will be useful when we come to choose our kernel. # -# $$K(\mathbf{X}, \mathbf{X}) = \begin{bmatrix} k(\mathbf{x}_1, \mathbf{x}_1) & \cdots & k(\mathbf{x}_1, \mathbf{x}_n) \\ \vdots & \ddots & \vdots \\ k(\mathbf{x}_n, \mathbf{x}_1) & \cdots & k(\mathbf{x}_n, \mathbf{x}_n) \end{bmatrix}$$ +# First, we'll construct our GPJax dataset, and will standardise the outputs, to match our +# assumption that the data has zero mean. + +# %% +test_x = jnp.linspace(1950, 2030, 5000, dtype=jnp.float64).reshape(-1, 1) +y_scaler = StandardScaler().fit(train_y) +standardised_train_y = y_scaler.transform(train_y) + +D = gpx.Dataset(X=train_x, y=standardised_train_y) + +# %% [markdown] +# Having constructed our dataset, we'll now define our kernel. We'll use a kernel which is +# composed of the sum of a linear kernel and a periodic kernel, as we saw in the previous +# section that this kernel is able to capture both the periodic and linear trends in the +# data. We'll also add an RBF kernel to the sum, which will allow us to capture any +# non-linear trends in the data: # -# such that $K(\mathbf{X}, \mathbf{X})_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$. +# $$\text{Kernel = Linear + Periodic + RBF}$$ # -# In order for $k$ to be a valid kernel/covariance function, the corresponding Gram matrix -# must be *positive semi-definite*. In this case the Gram matrix is referred to as a -# *covariance matrix*. A real $n \times n$ matrix $K$ is positive semi-definite if and -# only if for all vectors $\mathbf{z} \in \mathbb{R}^n$: # -# $$\mathbf{z}^\top K \mathbf{z} \geq 0$$ + +# %% +mean = gpx.mean_functions.Zero() +rbf_kernel = gpx.kernels.RBF(lengthscale=100.0) +periodic_kernel = gpx.kernels.Periodic() +linear_kernel = gpx.kernels.Linear() +sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel]) +final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel]) + +prior = gpx.Prior(mean_function=mean, kernel=final_kernel) +likelihood = gpx.Gaussian(num_datapoints=D.n) + +posterior = prior * likelihood + +# %% [markdown] +# With our model constructed, let's now fit it to the data, by minimising the negative log +# marginal likelihood of the data: + +# %% +negative_mll = gpx.objectives.ConjugateMLL(negative=True) +negative_mll(posterior, train_data=D) +negative_mll = jit(negative_mll) + +opt_posterior, history = gpx.fit( + model=posterior, + objective=negative_mll, + train_data=D, + optim=ox.adam(learning_rate=0.01), + num_iters=1000, + safe=True, + key=key, +) + +# %% [markdown] +# Now we can obtain the model's prediction over a period of time which includes the +# training data, as well as 8 years before and after the training data: + +# %% +latent_dist = opt_posterior.predict(test_x, train_data=D) +predictive_dist = opt_posterior.likelihood(latent_dist) + +predictive_mean = predictive_dist.mean().reshape(-1, 1) +predictive_std = predictive_dist.stddev().reshape(-1, 1) + +# %% [markdown] +# Let's plot the model's predictions over this period of time: + +# %% +fig, ax = plt.subplots(figsize=(10, 5)) +ax.plot( + train_x, standardised_train_y, "x", label="Observations", color=cols[0], alpha=0.5 +) +ax.fill_between( + test_x.squeeze(), + predictive_mean.squeeze() - 2 * predictive_std.squeeze(), + predictive_mean.squeeze() + 2 * predictive_std.squeeze(), + alpha=0.2, + label="Two sigma", + color=cols[1], +) +ax.plot( + test_x, + predictive_mean - 2 * predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax.plot( + test_x, + predictive_mean + 2 * predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], +) +ax.plot(test_x, predictive_mean, label="Predictive mean", color=cols[1]) +ax.set_xlabel("Year") +ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) + +# %% [markdown] +# We can see that the model seems to have captured the periodic trend in the data, as well +# as the (roughly) linear trend. This enables our model to make reasonable seeming +# predictions over the 8 years before and after the training data. Let's zoom in on the +# period from 2010 onwards: # -# Alternatively, a real $n \times n$ matrix $K$ is positive semi-definite if and only if all of its eigenvalues are non-negative. + +# %% +fig, ax = plt.subplots(figsize=(10, 5)) +ax.plot( + train_x[train_x >= 2010], + standardised_train_y[train_x >= 2010], + "x", + label="Observations", + color=cols[0], + alpha=0.5, +) +ax.fill_between( + test_x[test_x >= 2010].squeeze(), + predictive_mean[test_x >= 2010] - 2 * predictive_std[test_x >= 2010], + predictive_mean[test_x >= 2010] + 2 * predictive_std[test_x >= 2010], + alpha=0.2, + label="Two sigma", + color=cols[1], +) +ax.plot( + test_x[test_x >= 2010], + predictive_mean[test_x >= 2010] - 2 * predictive_std[test_x >= 2010], + linestyle="--", + linewidth=1, + color=cols[1], +) +ax.plot( + test_x[test_x >= 2010], + predictive_mean[test_x >= 2010] + 2 * predictive_std[test_x >= 2010], + linestyle="--", + linewidth=1, + color=cols[1], +) +ax.plot( + test_x[test_x >= 2010], + predictive_mean[test_x >= 2010], + label="Predictive mean", + color=cols[1], +) +ax.set_xlabel("Year") +ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) + +# %% [markdown] +# This certainly looks like a reasonable fit to the data, with sensible extrapolation +# beyond the training data, which finishes in November 2022. Moreover, the learned +# parameters of the kernel are interpretable. Let's take a look at the learned period of the periodic kernel: + +# %% +print( + f"Periodic Kernel Period: {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}" +) + +# %% [markdown] +# This tells us that the periodic trend learned has a period of $\approx 1$. This makes +# intuitive sense, as the unit of the input data is years, and we can see that the +# periodic trend tends to repeat itself roughly every year! # %% [markdown] # ## Defining Kernels on Non-Euclidean Spaces diff --git a/docs/sharp_bits.md b/docs/sharp_bits.md index 5b10fc5e1..91d544814 100644 --- a/docs/sharp_bits.md +++ b/docs/sharp_bits.md @@ -91,7 +91,7 @@ their own bijectors and attach them to the parameter(s) of their model. ### Why is positive-definiteness important? The Gram matrix of a kernel, a concept that we explore more in our -[kernels notebook](examples/kernels.py) and our [PyTree notebook](examples/pytrees.md), is a +[kernels notebook](examples/constructing_new_kernels.py) and our [PyTree notebook](examples/pytrees.md), is a symmetric positive definite matrix. As such, we have a range of tools at our disposal to make subsequent operations on the covariance matrix faster. One of these tools is the Cholesky factorisation that uniquely decomposes diff --git a/mkdocs.yml b/mkdocs.yml index dc091a2ea..b571a70b9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -29,7 +29,7 @@ nav: - Stochastic sparse GPs: examples/collapsed_vi.py - Pathwise Sampling for Spatial Modelling: examples/spatial.py - 📖 Guides for customisation: - - Kernels: examples/kernels.py + - Kernels: examples/constructing_new_kernels.py - Likelihoods: examples/likelihoods_guide.py - UCI regression: examples/yacht.py - 💻 Raw tutorial code: give_me_the_code.md