diff --git a/docs/_static/bijector_figure.svg b/docs/_static/bijector_figure.svg
new file mode 100644
index 000000000..5739003e5
--- /dev/null
+++ b/docs/_static/bijector_figure.svg
@@ -0,0 +1,813 @@
+
+
+
diff --git a/docs/_static/step_size_figure.svg b/docs/_static/step_size_figure.svg
new file mode 100644
index 000000000..f94859275
--- /dev/null
+++ b/docs/_static/step_size_figure.svg
@@ -0,0 +1,559 @@
+
+
+
diff --git a/docs/refs.bib b/docs/refs.bib
index 6c62450cc..1889d1888 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -112,3 +112,15 @@ @inproceedings{wilson2020efficient
numpages = {11},
series = {ICML'20}
}
+
+@book{higham2022accuracy,
+ author = {Higham, Nicholas J.},
+ title = {Accuracy and Stability of Numerical Algorithms},
+ publisher = {Society for Industrial and Applied Mathematics},
+ year = {2002},
+ doi = {10.1137/1.9780898718027},
+ address = {},
+ edition = {Second},
+ url = {https://epubs.siam.org/doi/abs/10.1137/1.9780898718027},
+ eprint = {https://epubs.siam.org/doi/pdf/10.1137/1.9780898718027}
+}
diff --git a/docs/scripts/sharp_bits_figure.py b/docs/scripts/sharp_bits_figure.py
new file mode 100644
index 000000000..dce70a8f3
--- /dev/null
+++ b/docs/scripts/sharp_bits_figure.py
@@ -0,0 +1,93 @@
+# ---
+# jupyter:
+# jupytext:
+# cell_metadata_filter: -all
+# custom_cell_magics: kql
+# text_representation:
+# extension: .py
+# format_name: percent
+# format_version: '1.3'
+# jupytext_version: 1.11.2
+# kernelspec:
+# display_name: gpjax_baselines
+# language: python
+# name: python3
+# ---
+
+# %%
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import matplotlib.patches as patches
+
+plt.style.use("../examples/gpjax.mplstyle")
+cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
+
+# %%
+fig, ax = plt.subplots()
+ax.axhline(y = 0.25, color=cols[0], linewidth=1.5)
+
+xs = [0.02, 0.06, 0.1, 0.17]
+ys = np.ones_like(xs) * 0.25
+
+ax.scatter(xs, ys, color=cols[1], marker="o", s=100, zorder=2)
+
+for idx, x in enumerate(xs):
+ ax.annotate(text = f'$\ell_{{t-{idx+1}}}$', xy=(x, 0.25), xytext=(x+0.01, 0.275), ha='center', va='bottom')
+
+
+style = "Simple, tail_width=0.5, head_width=4, head_length=8"
+kw = dict(arrowstyle=style, color="k")
+
+for i in range(len(xs)-1):
+ a = patches.FancyArrowPatch((xs[i+1], 0.25), (xs[i], 0.25), connectionstyle="arc3,rad=-.5", **kw)
+ ax.add_patch(a)
+
+
+ax.scatter(-0.03, 0.25, color=cols[1], marker="x", s=100, linewidth=5, zorder=2)
+
+a = patches.FancyArrowPatch((xs[0], 0.25), (-0.03, 0.25), connectionstyle="arc3,rad=-.5", **kw)
+ax.add_patch(a)
+
+ax.axvline(x = 0, color='black', linewidth=0.5, linestyle='-.')
+ax.get_yaxis().set_visible(False)
+ax.spines["left"].set_visible(False)
+ax.set_ylim(0., 0.5)
+ax.set_xlim(-0.07, 0.25)
+plt.savefig('../_static/step_size_figure.svg', bbox_inches='tight')
+
+# %%
+import tensorflow_probability.substrates.jax.bijectors as tfb
+import jax.numpy as jnp
+
+bij = tfb.Exp()
+
+x = np.linspace(0.05, 3., 6)
+y = np.asarray(bij.inverse(x))
+lval = 0.5
+rval = 0.52
+
+fig, ax = plt.subplots()
+ax.scatter(x, np.ones_like(x)*lval, s=100, label='Constrained value')
+ax.scatter(y, np.ones_like(y)*rval, marker='o', s=100, label='Unconstrained value')
+
+style = "Simple, tail_width=0.25, head_width=2, head_length=8"
+for i in range(len(x)):
+ if i%2 != 0:
+ a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=-.15", **kw)
+ # a = patches.Arrow(lval, x[i], rval-lval, y[i]-x[i], width=0.05, color='k')
+ else:
+ a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=.005", **kw)
+ ax.add_patch(a)
+
+ax.get_yaxis().set_visible(False)
+ax.spines["left"].set_visible(False)
+ax.legend(loc='best')
+# ax.set_ylim(0.1, 0.32)
+plt.savefig('../_static/bijector_figure.svg', bbox_inches='tight')
+
+# %%
+np.log(0.05)
+
+# %%
+x
diff --git a/docs/sharp_bits.md b/docs/sharp_bits.md
index 7484e7d6a..15153e6e8 100644
--- a/docs/sharp_bits.md
+++ b/docs/sharp_bits.md
@@ -2,16 +2,178 @@
## Pseudo-randomness
-Can briefly acknowledge and then point to the Jax docs for more information.
+Libraries like NumPy and Scipy use *stateful* pseudorandom number generators (PRNGs).
+However, the PRNG in JAX is stateless. This means that for a given function, the
+return always returns the same result unless the seed is changed. This is a good thing,
+but it means that we need to be careful when using JAX's PRNGs.
-## Float64
+To examine what it means for a PRNG to be stateful, consider the following example:
-The need for Float64 when inverting the Gram matrix
+```python
+import numpy as np
+import jax.random as jr
+key = jr.PRNGKey(123)
+
+# NumPy
+print('NumPy:')
+print(np.random.random())
+print(np.random.random())
+
+print('\nJAX:')
+print(jr.uniform(key))
+print(jr.uniform(key))
+
+print('\nSplitting key')
+key, subkey = jr.split(key)
+print(jr.uniform(subkey))
+```
+```console
+NumPy:
+0.5194454541172852
+0.9815886617924413
+
+JAX:
+0.95821166
+0.95821166
+
+Splitting key
+0.23886406
+```
+We can see that, in libraries like NumPy, the PRNG key's state is incremented whenever
+a pseudorandom call is made. This can make debugging difficult to manage as it is not
+always clear when a PRNG is being used. In JAX, the PRNG key is not incremented,
+so the same key will always return the same result. This has further positive benefits
+for reproducibility.
+
+GPJax relies on JAX's PRNGs for all random number generation. Whilst we try wherever possible to handle the PRNG key's state for you, care must be taken when defining your own models and inference schemes to ensure that the PRNG key is handled correctly. The [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers) has an excellent section on this.
+
+## Bijectors
+
+Parameters such as the kernel's lengthscale or variance have their support defined on
+a constrained subset of the real-line. During gradient-based optimisation, as we
+approach the set's boundary, it becomes possible that we could step outside of the
+set's support and introduce a numerical and mathematical error into our model. For
+example, consider the lengthscale parameter $`\ell`$, which we know must be strictly
+positive. If at $`t^{\text{th}}`$ iterate, our current estimate of $`\ell`$ was
+0.02 and our derivative informed us that $`\ell`$ should decrease, then if our
+learning rate is greater is than 0.03, we would end up with a negative variance term.
+We visualise this issue below where the red cross denotes the invalid lengthscale value
+that would be obtained, were we to optimise in the unconstrained parameter space.
+
+![](_static/step_size_figure.svg)
+
+A simple but impractical solution would be to use a tiny learning rate which would
+reduce the possibility of stepping outside of the parameter's support. However, this
+would be incredibly costly and does not eradicate the problem. An alternative solution
+is to apply a functional mapping to the parameter that projects it from a constrained
+subspace of the real-line onto the entire real-line. Here, gradient updates are
+applied in the unconstrained parameter space before transforming the value back to the
+original support of the parameters. Such a transformation is known as a bijection.
+
+![](_static/bijector_figure.svg)
+
+To help understand this, we show the effect of using a log-exp bijector in the above
+figure. We have six points on the positive real line that range from 0.1 to 3 depicted
+by a blue cross. We then apply the bijector by log-transforming the constrained value.
+This gives us the points' unconstrained value which we depict by a red circle. It is
+this value that we apply gradient updates to. When we wish to recover the constrained
+value, we apply the inverse of the bijector, which is the exponential function in this
+case. This gives us back the blue cross.
+
+In GPJax, we supply bijective functions using [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors).
+In our [PyTrees doc](examples/pytrees.md) document, we detail how the user can define
+their own bijectors and attach them to the parameter(s) of their model.
## Positive-definiteness
-The need for jitter in the kernel Gram matrix
+> "Symmetric positive definiteness is one of the highest accolades to which a matrix can aspire" - Nicholas Highman, Accuracy and stability of numerical algorithms [@higham2022accuracy]
+
+### Why is positive-definiteness important?
+
+The Gram matrix of a kernel, a concept that we explore more in our
+[kernels notebook](examples/kernels.py) and our [PyTree notebook](examples/pytrees.md), is a
+symmetric positive definite matrix. As such, we
+have a range of tools at our disposal to make subsequent operations on the covariance
+matrix faster. One of these tools is the Cholesky factorisation that uniquely decomposes
+any symmetric positive-definite matrix $`\mathbf{\Sigma}`$ by
+
+```math
+\begin{align}
+ \mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}\,,
+\end{align}
+```
+where $`\mathbf{L}`$ is a lower triangular matrix.
+
+We make use of this result in GPJax when solving linear systems of equations of the
+form $`\mathbf{A}\boldsymbol{x} = \boldsymbol{b}`$. Whilst seemingly abstract at first,
+such problems are frequently encountered when constructing Gaussian process models. One
+such example is frequently encountered in the regression setting for learning Gaussian
+process kernel hyperparameters. Here we have labels
+$`\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \sigma^2\mathbf{I})`$ with $`f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})`$ arising from zero-mean
+Gaussian process prior and Gram matrix $`\mathbf{K}_{\boldsymbol{xx}}`$ at the inputs
+$`\boldsymbol{x}`$. Here the marginal log-likelihood comprises the following form
+
+```math
+\begin{align}
+ \log p(\boldsymbol{y}) = 0.5\left(-\boldsymbol{y}^{\top}\left(\mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I} \right)^{-1}\boldsymbol{y} -\log\lvert \mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I}\rvert -n\log(2\pi)\right) ,
+\end{align}
+```
+
+and the goal of inference is to maximise kernel hyperparameters (contained in the Gram
+matrix $`\mathbf{K}_{\boldsymbol{xx}}`$) and likelihood hyperparameters (contained in the
+noise covariance $`\sigma^2\mathbf{I}`$). Computing the marginal log-likelihood (and its
+gradients), draws our attention to the term
+
+```math
+\begin{align}
+ \underbrace{\left(\mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I} \right)^{-1}}_{\mathbf{A}}\boldsymbol{y},
+\end{align}
+```
+
+then we can see a solution can be obtained by solving the corresponding system of
+equations. By working with $`\mathbf{L} = \operatorname{chol}{\mathbf{A}}`$ instead of
+$`\mathbf{A}`$, we save a significant amount of floating-point operations (flops) by
+solving two triangular systems of equations (one for $`\mathbf{L}`$ and another for
+$`\mathbf{L}^{\top}`$) instead of one dense system of equations. Solving two triangular systems
+of equations has complexity $`\mathcal{O}(n^3/6)`$; a vast improvement compared to
+regular solvers that have $`\mathcal{O}(n^3)`$ complexity in the number of datapoints
+$`n`$.
+
+### The Cholesky drawback
+
+While the computational acceleration provided by using Cholesky factors instead of dense
+matrices is hopefully now apparent, an awkward numerical instability _gotcha_ can arise
+due to floating-point rounding errors. When we evaluate a covariance function on a set
+of points that are very _close_ to one another, eigenvalues of the corresponding
+Gram matrix can get very small. So small that after numerical rounding, the
+smallest eigenvalues can become negative-valued. While not truly less than zero, our
+computer thinks they are, which becomes a problem when we want to compute a Cholesky
+factor since this requires that the input matrix is positive-definite. If there are
+negative eigenvalues, then this stipulation has been invalidated.
+
+To resolve this, we apply some numerical _jitter_ to the diagonals of any Gram matrix.
+Typically this is incredibly small, with $`10^{-6}`$ being the system default. However,
+for some problems, this amount may need to be increased.
## Slow-to-evaluate
-More than several thousand data points will require the use of inducing points - don't try and use the ConjugateMLL objective on a million data points.
+Famously, a regular Gaussian process model (as detailed in
+[our regression notebook](examples/regression.py)) will scale cubically in the number of data points.
+Consequently, if you try to fit your Gaussian process model to a data set containing more
+than several thousand data points, then you will likely incur a significant
+computational overhead. In such cases, we recommend using Sparse Gaussian processes to
+alleviate this issue.
+
+Approximately, when the data contains less than 50000 data points, we recommend using
+the uncollapsed evidence lower bound objective [@titsias2009] to optimise the parameters
+of your sparse Gaussian process model. Such a model will scale linearly in the number of
+data points and quadratically in the number of inducing points. We demonstrate its use
+in [our sparse regression notebook](examples/collapsed_vi.py).
+
+For data sets exceeding 50000 data points, even the sparse Gaussian process outlined
+above will become computationally infeasible. In such cases, we recommend using the
+collapsed evidence lower bound objective [@hensman2013gaussian] that allows stochastic
+mini-batch optimisation of the parameters of your sparse Gaussian process model. Such a
+model will scale linearly in the batch size and quadratically in the number of inducing
+points. We demonstrate its use in
+[our sparse stochastic variational inference notebook](examples/uncollapsed_vi.py)
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 4ffb331f0..cd8c8c538 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -12,7 +12,7 @@ nav:
- 🛠️ Installation: installation.md
- 🎨 Design principles: design.md
- 🤝 Contributing: contributing.md
- # - 🔪 Sharp bits: sharp_bits.md
+ - 🔪 Sharp bits: sharp_bits.md
- 🌳 GPJax PyTrees: examples/pytrees.md
- 📎 JAX 101 [External]: https://jax.readthedocs.io/en/latest/jax-101/index.html
- 💡 Background:
@@ -86,7 +86,7 @@ plugins:
csl_file: "https://raw.githubusercontent.com/citation-style-language/styles/af38aba0e9b08406c8827abfc888e5f3e3fa1d65/journal-of-the-royal-statistical-society.csl"
cite_inline: true
- mkdocs-jupyter:
- execute: true
+ execute: false
allow_errors: false
# binder: true
# binder_service_name: "gh"