Skip to content

Commit

Permalink
Test Gaussian cov is PSD, bump workflows to python 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Aug 29, 2023
1 parent 1c3f202 commit 177a5d7
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]
steps:
- name: Checkout the branch
uses: actions/checkout@v3.5.2
Expand All @@ -24,7 +24,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v3
with:
python-version: "3.8"
python-version: "3.10"

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
# Grap the latest commit from the branch
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.8"]
python-version: ["3.10"]

steps:
# Grap the latest commit from the branch
Expand Down
2 changes: 1 addition & 1 deletion gpjax/lower_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cola
import jax.numpy as jnp

# TODO: Add lower_cholesky for other linear operators.
# TODO: Once this functionality is supported in CoLA, remove this.


@cola.dispatch
Expand Down
8 changes: 8 additions & 0 deletions tests/test_gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def test_array_arguments(n: int) -> None:
assert approx_equal(dist.stddev(), jnp.sqrt(covariance.diagonal()))
assert approx_equal(dist.covariance(), covariance)

assert isinstance(dist.scale, Dense)
assert cola.PSD in dist.scale.annotations

y = jr.uniform(_key, shape=(n,))

tfp_dist = MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance)
Expand All @@ -74,9 +77,14 @@ def test_diag_linear_operator(n: int) -> None:
mean = jr.uniform(key_mean, shape=(n,))
diag = jr.uniform(key_diag, shape=(n,))

# We purosely forget to add a PSD annotation to the diagonal matrix.
dist_diag = GaussianDistribution(loc=mean, scale=Diagonal(diag**2))
tfp_dist = MultivariateNormalDiag(loc=mean, scale_diag=diag)

# We check that the PSD annotation is added automatically.
assert isinstance(dist_diag.scale, Diagonal)
assert cola.PSD in dist_diag.scale.annotations

assert approx_equal(dist_diag.mean(), tfp_dist.mean())
assert approx_equal(dist_diag.mode(), tfp_dist.mode())
assert approx_equal(dist_diag.entropy(), tfp_dist.entropy())
Expand Down
3 changes: 2 additions & 1 deletion tests/test_kernels/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from gpjax.kernels.computations import (
ConstantDiagonalKernelComputation,
DiagonalKernelComputation,
BasisFunctionComputation,
)
from gpjax.kernels.nonstationary import (
Linear,
Expand Down Expand Up @@ -77,4 +78,4 @@ def test_change_computation(kernel):
assert jnp.allclose(constant_entries, constant_entries[0])

# All the off diagonal entries should be zero
assert jnp.allclose(constant_diagonal_matrix - jnp.diag(constant_entries), 0.0)
assert jnp.allclose(constant_diagonal_matrix - jnp.diag(constant_entries), 0.0)

0 comments on commit 177a5d7

Please sign in to comment.