Merge pull request #432 from JaxGaussianProcesses/st/kernel_notebook
Clean up intro to kernels notebook
thomaspinder authored Jan 25, 2024
2 parents c1599e4 + 0e12dcc commit 991f235
Showing 1 changed file with 18 additions and 52 deletions.
70 changes: 18 additions & 52 deletions docs/examples/intro_to_kernels.py
@@ -201,7 +201,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)


n = 5
n = 13

training_x = jr.uniform(key=key, minval=0, maxval=1, shape=(n,)).reshape(-1, 1)
training_y = forrester(training_x)
@@ -235,6 +235,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)


# %%
opt_posterior, history = gpx.fit_scipy(
model=no_opt_posterior,
objective=negative_mll,
@@ -247,6 +249,17 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# with the optimised hyperparameters, and compare them to the predictions made using the
# posterior with the default hyperparameters:
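
As a quick check that the optimisation improved the objective itself, the negative marginal log-likelihood can be re-evaluated on both posteriors. A minimal sketch, assuming the negative_mll, D, no_opt_posterior, and opt_posterior objects defined in the surrounding cells:

# Negative marginal log-likelihood before and after optimisation (lower is better).
print("Initial negative MLL:  ", negative_mll(no_opt_posterior, train_data=D))
print("Optimised negative MLL:", negative_mll(opt_posterior, train_data=D))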

# %%
def plot_ribbon(ax, x, dist, color):
    # Plot the predictive mean together with a shaded two-sigma credible band.
    mean = dist.mean()
    std = dist.stddev()
    ax.plot(x, mean, label="Predictive mean", color=color)
    ax.fill_between(
        x.squeeze(), mean - 2 * std, mean + 2 * std,
        alpha=0.2, label="Two sigma", color=color,
    )
    ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
    ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)



# %%
opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)
@@ -255,69 +268,22 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
opt_predictive_std = opt_predictive_dist.stddev()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6))
ax1.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5)
ax1.fill_between(
test_x.squeeze(),
opt_predictive_mean - 2 * opt_predictive_std,
opt_predictive_mean + 2 * opt_predictive_std,
alpha=0.2,
label="Two sigma",
color=cols[1],
)
ax1.plot(
test_x,
opt_predictive_mean - 2 * opt_predictive_std,
linestyle="--",
linewidth=1,
color=cols[1],
)
ax1.plot(
test_x,
opt_predictive_mean + 2 * opt_predictive_std,
linestyle="--",
linewidth=1,
color=cols[1],
)
ax1.plot(
test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax1.plot(test_x, opt_predictive_mean, label="Predictive mean", color=cols[1])
ax1.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5)
plot_ribbon(ax1, test_x, opt_predictive_dist, color=cols[1])
ax1.set_title("Posterior with Hyperparameter Optimisation")
ax1.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))

no_opt_latent_dist = no_opt_posterior.predict(test_x, train_data=D)
no_opt_predictive_dist = no_opt_posterior.likelihood(no_opt_latent_dist)

no_opt_predictive_mean = no_opt_predictive_dist.mean()
no_opt_predictive_std = no_opt_predictive_dist.stddev()

ax2.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5)
ax2.fill_between(
test_x.squeeze(),
no_opt_predictive_mean - 2 * no_opt_predictive_std,
no_opt_predictive_mean + 2 * no_opt_predictive_std,
alpha=0.2,
label="Two sigma",
color=cols[1],
)
ax2.plot(
test_x,
no_opt_predictive_mean - 2 * no_opt_predictive_std,
linestyle="--",
linewidth=1,
color=cols[1],
)
ax2.plot(
test_x,
no_opt_predictive_mean + 2 * no_opt_predictive_std,
linestyle="--",
linewidth=1,
color=cols[1],
)
ax2.plot(
test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2
)
ax2.plot(test_x, no_opt_predictive_mean, label="Predictive mean", color=cols[1])
ax2.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5)
plot_ribbon(ax2, test_x, no_opt_predictive_dist, color=cols[1])
ax2.set_title("Posterior without Hyperparameter Optimisation")
ax2.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
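
To put a number on the comparison above, the two predictive means can be scored against the latent Forrester function on the test inputs. A minimal sketch, assuming the test_x / test_y arrays and the two predictive distributions constructed above:

# Root-mean-square error of each predictive mean against the latent function.
rmse_opt = jnp.sqrt(jnp.mean((opt_predictive_dist.mean() - test_y.squeeze()) ** 2))
rmse_no_opt = jnp.sqrt(jnp.mean((no_opt_predictive_dist.mean() - test_y.squeeze()) ** 2))
print("RMSE with optimised hyperparameters:", float(rmse_opt))
print("RMSE with default hyperparameters:  ", float(rmse_no_opt))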

