Skip to content

Commit

Permalink
Some final rebase adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
ingmarschuster committed Sep 16, 2023
1 parent a336daa commit 9bcc366
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 14 deletions.
1 change: 0 additions & 1 deletion gpjax/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Optional,
Union,
Literal,
Literal,
)
import jax.numpy as jnp
from jaxtyping import (
Expand Down
7 changes: 0 additions & 7 deletions gpjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,6 @@ def log_prob(
mu = self.loc
sigma = self.scale
n = mu.shape[-1]
if mask is not None:
y = jnp.where(mask, 0.0, y)
mu = jnp.where(mask, 0.0, mu)
sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.matrix)
sigma = sigma.replace(
matrix=jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked)
)

if mask is not None:
y = jnp.where(mask, 0.0, y)
Expand Down
6 changes: 3 additions & 3 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,6 @@ def predict(
m = y.shape[1]
if m > 1 and mask is not None:
mask = mask.flatten()
n_X_m = n * m
# Unpack test inputs
t = test_inputs

Expand Down Expand Up @@ -524,8 +523,9 @@ def predict(
)

mean_t = self.prior.mean_function(t)
Ktt = jnp.kron(self.prior.kernel.gram(t).to_dense(), Kyy.to_dense())
Kxt = jnp.kron(self.prior.kernel.cross_covariance(x, t), Kyy.to_dense())
Ktt = cola.ops.Kronecker(self.prior.kernel.gram(t), Kyy)
Ktt = cola.PSD(Ktt)
Kxt = cola.ops.Kronecker(self.prior.kernel.cross_covariance(x, t), Kyy)

# Σ⁻¹ Kxt
if mask is not None:
Expand Down
1 change: 0 additions & 1 deletion gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def step(
m = y.shape[1]
if m > 1 and mask is not None:
mask = mask.flatten()
n_X_m = n * m

# Observation noise o²
obs_noise = posterior.likelihood.obs_noise
Expand Down
2 changes: 0 additions & 2 deletions tests/test_gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
)

from gpjax.distributions import GaussianDistribution
from gpjax.linops.dense_linear_operator import DenseLinearOperator
from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator

_key = jr.PRNGKey(seed=42)

Expand Down

0 comments on commit 9bcc366

Please sign in to comment.