Skip to content

Commit

Permalink
Fixed Polar GP example (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
trsav authored Jul 31, 2023
1 parent cbfe6c5 commit 1c80261
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,27 +218,21 @@ class Polar(gpx.kernels.AbstractKernel):
period: float = static_field(2 * jnp.pi)
tau: float = param_field(jnp.array([4.0]), bijector=bij)

def __post_init__(self):
self.c = self.period / 2.0

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
t = angular_distance(x, y, self.c)
K = (1 + self.tau * t / self.c) * jnp.clip(
1 - t / self.c, 0, jnp.inf
) ** self.tau
c = self.period / 2.0
t = angular_distance(x, y, c)
K = (1 + self.tau * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau
return K.squeeze()


# %% [markdown]
# We unpack this now to make better sense of it. In the kernel's initialiser
# we specify the length of a single period. As the underlying
# domain is a circle, this is $2\pi$. Next, we define
# the Kernel's half-period parameter. As the kernel is a `dataclass` and `c` is
# function of `period`, we must define it in the `__post_init__` method.
# Finally, we define the kernel's `__call__`
# function which is a direct implementation of Equation (1).
# domain is a circle, this is $2\pi$. We then define the kernel's `__call__`
# function which is a direct implementation of Equation (1) where we define `c`
# as half the value of `period`.
#
# To constrain $\tau$ to be greater than 4, we use a `Softplus` bijector with a
# clipped lower bound of 4.0. This is done by specifying the `bijector` argument
Expand Down Expand Up @@ -267,11 +261,11 @@ def __call__(
PKern = Polar()
meanf = gpx.mean_functions.Zero()
likelihood = gpx.Gaussian(num_datapoints=n)
circlular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using Adam
opt_posterior, history = gpx.fit(
model=circlular_posterior,
model=circular_posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.05),
Expand Down Expand Up @@ -302,6 +296,7 @@ def __call__(
alpha=0.3,
label=r"1 Posterior s.d.",
color=cols[1],
lw=0,
)
ax.fill_between(
angles.squeeze(),
Expand All @@ -310,6 +305,7 @@ def __call__(
alpha=0.15,
label=r"3 Posterior s.d.",
color=cols[1],
lw=0,
)
ax.plot(angles, mu, label="Posterior mean")
ax.scatter(D.X, D.y, alpha=1, label="Observations")
Expand Down

0 comments on commit 1c80261

Please sign in to comment.