From a08f6821b814c3095919c8066e4db59f659436be Mon Sep 17 00:00:00 2001
From: ST John
Date: Thu, 25 Jan 2024 14:58:29 +0200
Subject: [PATCH 1/3] increase number of datapoints to fix inference "issue"

---
 docs/examples/intro_to_kernels.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py
index 55113be61..f97603a1b 100644
--- a/docs/examples/intro_to_kernels.py
+++ b/docs/examples/intro_to_kernels.py
@@ -201,7 +201,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
     return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)
 
 
-n = 5
+n = 13
 training_x = jr.uniform(key=key, minval=0, maxval=1, shape=(n,)).reshape(-1, 1)
 training_y = forrester(training_x)
 

From e7b7d7f383490913505fcadf389012611e686c25 Mon Sep 17 00:00:00 2001
From: ST John
Date: Thu, 25 Jan 2024 15:00:06 +0200
Subject: [PATCH 2/3] factor out plot_ribbon function

---
 docs/examples/intro_to_kernels.py | 64 ++++++++----------------------
 1 file changed, 15 insertions(+), 49 deletions(-)

diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py
index f97603a1b..9ee60cdd0 100644
--- a/docs/examples/intro_to_kernels.py
+++ b/docs/examples/intro_to_kernels.py
@@ -235,6 +235,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 
 negative_mll = gpx.objectives.ConjugateMLL(negative=True)
 negative_mll(no_opt_posterior, train_data=D)
+
+# %%
 opt_posterior, history = gpx.fit_scipy(
     model=no_opt_posterior,
     objective=negative_mll,
@@ -247,6 +249,17 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 # with the optimised hyperparameters, and compare them to the predictions made using the
 # posterior with the default hyperparameters:
 
+# %%
+def plot_ribbon(ax, x, dist, color):
+    mean = dist.mean()
+    std = dist.stddev()
+    ax.plot(x, mean, label="Predictive mean", color=color)
+    ax.fill_between(x.squeeze(), mean - 2 * std, mean + 2 * std,
+                    alpha=0.2, label="Two sigma", color=color)
+    ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
+    ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)
+
+
 # %%
 opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
 opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)
@@ -256,68 +269,21 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 
 fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6))
 ax1.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5)
-ax1.fill_between(
-    test_x.squeeze(),
-    opt_predictive_mean - 2 * opt_predictive_std,
-    opt_predictive_mean + 2 * opt_predictive_std,
-    alpha=0.2,
-    label="Two sigma",
-    color=cols[1],
-)
-ax1.plot(
-    test_x,
-    opt_predictive_mean - 2 * opt_predictive_std,
-    linestyle="--",
-    linewidth=1,
-    color=cols[1],
-)
-ax1.plot(
-    test_x,
-    opt_predictive_mean + 2 * opt_predictive_std,
-    linestyle="--",
-    linewidth=1,
-    color=cols[1],
-)
 ax1.plot(
     test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2
 )
-ax1.plot(test_x, opt_predictive_mean, label="Predictive mean", color=cols[1])
+plot_ribbon(ax1, test_x, opt_predictive_dist, color=cols[1])
 ax1.set_title("Posterior with Hyperparameter Optimisation")
 ax1.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
 
 no_opt_latent_dist = no_opt_posterior.predict(test_x, train_data=D)
 no_opt_predictive_dist = no_opt_posterior.likelihood(no_opt_latent_dist)
-no_opt_predictive_mean = no_opt_predictive_dist.mean()
-no_opt_predictive_std = no_opt_predictive_dist.stddev()
-
 ax2.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5)
training_y, "x", label="Observations", color=cols[0], alpha=0.5) -ax2.fill_between( - test_x.squeeze(), - no_opt_predictive_mean - 2 * no_opt_predictive_std, - no_opt_predictive_mean + 2 * no_opt_predictive_std, - alpha=0.2, - label="Two sigma", - color=cols[1], -) -ax2.plot( - test_x, - no_opt_predictive_mean - 2 * no_opt_predictive_std, - linestyle="--", - linewidth=1, - color=cols[1], -) -ax2.plot( - test_x, - no_opt_predictive_mean + 2 * no_opt_predictive_std, - linestyle="--", - linewidth=1, - color=cols[1], -) ax2.plot( test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2 ) -ax2.plot(test_x, no_opt_predictive_mean, label="Predictive mean", color=cols[1]) +plot_ribbon(ax2, test_x, no_opt_predictive_dist, color=cols[1]) ax2.set_title("Posterior without Hyperparameter Optimisation") ax2.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) From 0e12dcc72a1283bfd16b3d1a277b2cb7fefcd992 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 25 Jan 2024 15:00:17 +0200 Subject: [PATCH 3/3] make observation markers more visible --- docs/examples/intro_to_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index 9ee60cdd0..1519ee014 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -268,10 +268,10 @@ def plot_ribbon(ax, x, dist, color): opt_predictive_std = opt_predictive_dist.stddev() fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6)) -ax1.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5) ax1.plot( test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2 ) +ax1.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5) plot_ribbon(ax1, test_x, opt_predictive_dist, color=cols[1]) ax1.set_title("Posterior with Hyperparameter Optimisation") ax1.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) @@ -279,10 +279,10 @@ def plot_ribbon(ax, x, dist, color): no_opt_latent_dist = no_opt_posterior.predict(test_x, train_data=D) no_opt_predictive_dist = no_opt_posterior.likelihood(no_opt_latent_dist) -ax2.plot(training_x, training_y, "x", label="Observations", color=cols[0], alpha=0.5) ax2.plot( test_x, test_y, label="Latent function", color=cols[0], linestyle="--", linewidth=2 ) +ax2.plot(training_x, training_y, "x", label="Observations", color="k", zorder=5) plot_ribbon(ax2, test_x, no_opt_predictive_dist, color=cols[1]) ax2.set_title("Posterior without Hyperparameter Optimisation") ax2.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))