From 0a62fb4841f87f81d6539a11ca23745caadfda03 Mon Sep 17 00:00:00 2001
From: frazane
Date: Thu, 30 Nov 2023 14:52:27 +0100
Subject: [PATCH 1/5] generic types for gps objects

---
 gpjax/gps.py         | 112 +++++++++++++++++++++++++++----------------
 gpjax/likelihoods.py |   4 +-
 2 files changed, 72 insertions(+), 44 deletions(-)

diff --git a/gpjax/gps.py b/gpjax/gps.py
index 5928ef491..a58133859 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+# from __future__ import annotations
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import overload
+from typing import (
+    TYPE_CHECKING,
+    Generic,
+    TypeVar,
+    overload,
+)
 
 from beartype.typing import (
     Any,
@@ -47,7 +52,7 @@
 from gpjax.likelihoods import (
     AbstractLikelihood,
     Gaussian,
-    NonGaussianLikelihood,
+    NonGaussian,
 )
 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.mean_functions import AbstractMeanFunction
@@ -57,13 +62,19 @@
     KeyArray,
 )
 
+Kernel = TypeVar("Kernel", bound=AbstractKernel)
+MeanFunction = TypeVar("MeanFunction", bound=AbstractMeanFunction)
+Likelihood = TypeVar("Likelihood", bound=AbstractLikelihood)
+NonGaussianLikelihood = TypeVar("NonGaussianLikelihood", bound=NonGaussian)
+GaussianLikelihood = TypeVar("GaussianLikelihood", bound=Gaussian)
+
 
 @dataclass
-class AbstractPrior(Module):
+class AbstractPrior(Module, Generic[MeanFunction, Kernel]):
     r"""Abstract Gaussian process prior."""
 
-    kernel: AbstractKernel
-    mean_function: AbstractMeanFunction
+    kernel: Kernel
+    mean_function: MeanFunction
     jitter: float = static_field(1e-6)
 
     def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
@@ -113,7 +124,7 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
 # GP Priors
 #######################
 @dataclass
-class Prior(AbstractPrior):
+class Prior(AbstractPrior[MeanFunction, Kernel]):
     r"""A Gaussian process prior object.
 
     The GP is parameterised by a
@@ -136,18 +147,27 @@ class Prior(AbstractPrior):
         >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
         ```
     """
-
-    # @overload
-    # def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
-    #     ...
-
-    # @overload
-    # def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-    #     ...
-
-    # @overload
-    # def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-    #     ...
+    if TYPE_CHECKING:
+
+        @overload
+        def __mul__(
+            self, other: GaussianLikelihood
+        ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
+            ...
+
+        @overload
+        def __mul__(
+            self, other: NonGaussianLikelihood
+        ) -> (
+            "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
+        ):
+            ...
+
+        @overload
+        def __mul__(
+            self, other: Likelihood
+        ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
+            ...
 
     def __mul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -183,17 +203,27 @@ def __mul__(self, other):
         """
         return construct_posterior(prior=self, likelihood=other)
 
-    # @overload
-    # def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
-    #     ...
+    if TYPE_CHECKING:
+
+        @overload
+        def __rmul__(
+            self, other: GaussianLikelihood
+        ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
+            ...
 
-    # @overload
-    # def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-    #     ...
+        @overload
+        def __rmul__(
+            self, other: NonGaussianLikelihood
+        ) -> (
+            "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
+        ):
+            ...
 
-    # @overload
-    # def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-    #     ...
+        @overload
+        def __rmul__(
+            self, other: Likelihood
+        ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
+            ...
 
     def __rmul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -324,19 +354,22 @@ def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
 
     return sample_fn
 
 
+PriorType = TypeVar("PriorType", bound=AbstractPrior)
+
+
 #######################
 # GP Posteriors
 #######################
 @dataclass
-class AbstractPosterior(Module):
+class AbstractPosterior(Module, Generic[PriorType, Likelihood]):
     r"""Abstract Gaussian process posterior.
 
     The base GP posterior object conditioned on an observed dataset. All
     posterior objects should inherit from this class.
     """
 
-    prior: AbstractPrior
-    likelihood: AbstractLikelihood
+    prior: AbstractPrior[MeanFunction, Kernel]
+    likelihood: Likelihood
     jitter: float = static_field(1e-6)
 
     def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
@@ -381,7 +414,7 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
 
 
 @dataclass
-class ConjugatePosterior(AbstractPosterior):
+class ConjugatePosterior(AbstractPosterior[PriorType, GaussianLikelihood]):
     r"""A Conjugate Gaussian process posterior object.
 
     A Gaussian process posterior distribution when the constituent likelihood
@@ -600,7 +633,7 @@ def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
 
 
 @dataclass
-class NonConjugatePosterior(AbstractPosterior):
+class NonConjugatePosterior(AbstractPosterior[PriorType, NonGaussianLikelihood]):
     r"""A non-conjugate Gaussian process posterior object.
 
     A Gaussian process posterior object for models where the likelihood is
@@ -685,22 +718,17 @@ def predict(
 #######################
 
 
-@overload
-def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior:
-    ...
-
-
 @overload
 def construct_posterior(
-    prior: Prior, likelihood: NonGaussianLikelihood
-) -> NonConjugatePosterior:
+    prior: PriorType, likelihood: GaussianLikelihood
+) -> ConjugatePosterior[PriorType, GaussianLikelihood]:
     ...
 
 
 @overload
 def construct_posterior(
-    prior: Prior, likelihood: AbstractLikelihood
-) -> AbstractPosterior:
+    prior: PriorType, likelihood: NonGaussianLikelihood
+) -> NonConjugatePosterior[PriorType, NonGaussianLikelihood]:
     ...
 
diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py
index b7fcb724b..4f93dfa51 100644
--- a/gpjax/likelihoods.py
+++ b/gpjax/likelihoods.py
@@ -247,11 +247,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
     return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter
 
 
-NonGaussianLikelihood = Union[Poisson, Bernoulli]
+NonGaussian = Union[Poisson, Bernoulli]
 
 __all__ = [
     "AbstractLikelihood",
-    "NonGaussianLikelihood",
+    "NonGaussian",
     "Gaussian",
     "Bernoulli",
     "Poisson",
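The effect of [PATCH 1/5] is easiest to see from user code. Below is a minimal usage sketch, not part of the patch itself, of how a static type checker can now narrow the result of `prior * likelihood` through the new `TYPE_CHECKING` overloads; the `RBF`, `Zero`, `Gaussian`, and `Bernoulli` constructors are assumed from GPJax's public API as referenced in the docstrings above.

```python
# Hypothetical sketch (assumes the public gpjax API shown in the docstrings).
# At runtime both products dispatch to construct_posterior; the overloads
# exist only for the type checker.
import gpjax as gpx

prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())

# Gaussian likelihood: inferred as ConjugatePosterior[Prior[...], Gaussian]
conjugate = prior * gpx.likelihoods.Gaussian(num_datapoints=100)

# Bernoulli is in the NonGaussian union: inferred as NonConjugatePosterior[...]
non_conjugate = prior * gpx.likelihoods.Bernoulli(num_datapoints=100)
```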
From a6da54871e15d043f25048ff7b80c5f8d47dfc38 Mon Sep 17 00:00:00 2001
From: frazane
Date: Thu, 30 Nov 2023 15:47:20 +0100
Subject: [PATCH 2/5] version bump

---
 gpjax/__init__.py | 2 +-
 pyproject.toml    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 6e45ae72d..4a86e0107 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -81,7 +81,7 @@
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.7.3"
+__version__ = "0.7.4"
 
 __all__ = [
     "Module",
diff --git a/pyproject.toml b/pyproject.toml
index 6e2aa4537..0c2c8745b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gpjax"
-version = "0.7.3"
+version = "0.7.4"
 description = "Gaussian processes in JAX."
 authors = [
     "Thomas Pinder ",

From 0a8ffac0c288ada35667c2f8de803d1f64a98d6c Mon Sep 17 00:00:00 2001
From: frazane
Date: Thu, 30 Nov 2023 16:22:09 +0100
Subject: [PATCH 3/5] v 0.8.0

---
 gpjax/__init__.py                  |  2 +-
 gpjax/gps.py                       | 34 ++++++++++++++++--------------
 gpjax/kernels/base.py              |  3 ++-
 gpjax/kernels/computations/base.py |  4 ++--
 pyproject.toml                     |  2 +-
 5 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 4a86e0107..c48e204d2 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -81,7 +81,7 @@
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.7.4"
+__version__ = "0.8.0"
 
 __all__ = [
     "Module",
diff --git a/gpjax/gps.py b/gpjax/gps.py
index a58133859..aa262d894 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -28,7 +28,9 @@
     Optional,
     Union,
 )
-import cola
+from cola.linalg.inverse.inv import solve
+from cola.annotations import PSD
+from cola.ops.operators import I_like
 from cola.linalg.decompositions.decompositions import Cholesky
 import jax.numpy as jnp
 from jax.random import (
@@ -275,8 +277,8 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.jitter
-        Kxx = cola.PSD(Kxx)
+        Kxx += I_like(Kxx) * self.jitter
+        Kxx = PSD(Kxx)
 
         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
 
@@ -522,24 +524,24 @@ def predict(
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.jitter
+        Kxx += I_like(Kxx) * self.jitter
 
         # Σ = Kxx + Io²
-        Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise
-        Sigma = cola.PSD(Sigma)
+        Sigma = Kxx + I_like(Kxx) * obs_noise
+        Sigma = PSD(Sigma)
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
+        Sigma_inv_Kxt = solve(Sigma, Kxt)
 
         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
 
         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += cola.ops.I_like(covariance) * self.prior.jitter
-        covariance = cola.PSD(covariance)
+        covariance += I_like(covariance) * self.prior.jitter
+        covariance = PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
@@ -601,11 +603,11 @@ def sample_approx(
         # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ~ N(0, o²)
         obs_var = self.likelihood.obs_stddev**2
         Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
+        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
         eps = jnp.sqrt(obs_var) * normal(key, [train_data.n, num_samples])  # [N, B]
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
-        canonical_weights = cola.solve(
+        canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
             Cholesky(),
@@ -684,8 +686,8 @@ def predict(
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += cola.ops.I_like(Kxx) * self.prior.jitter
-        Kxx = cola.PSD(Kxx)
+        Kxx += I_like(Kxx) * self.prior.jitter
+        Kxx = PSD(Kxx)
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs
@@ -697,7 +699,7 @@ def predict(
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
@@ -707,8 +709,8 @@ def predict(
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += cola.ops.I_like(covariance) * self.prior.jitter
-        covariance = cola.PSD(covariance)
+        covariance += I_like(covariance) * self.prior.jitter
+        covariance = PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py
index ff9e7f8b6..4605ef4a3 100644
--- a/gpjax/kernels/base.py
+++ b/gpjax/kernels/base.py
@@ -29,6 +29,7 @@
     Num,
 )
 import tensorflow_probability.substrates.jax.distributions as tfd
+from cola.ops.operators import LinearOperator
 
 from gpjax.base import (
     Module,
@@ -60,7 +61,7 @@ def ndims(self):
     def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
         return self.compute_engine.cross_covariance(self, x, y)
 
-    def gram(self, x: Num[Array, "N D"]):
+    def gram(self, x: Num[Array, "N D"]) -> LinearOperator:
         return self.compute_engine.gram(self, x)
 
     def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py
index ac48b8101..54636fac0 100644
--- a/gpjax/kernels/computations/base.py
+++ b/gpjax/kernels/computations/base.py
@@ -17,8 +17,8 @@
 from dataclasses import dataclass
 import typing as tp
 
-from cola import PSD
-from cola.ops import (
+from cola.annotations import PSD
+from cola.ops.operators import (
     Dense,
     Diagonal,
     LinearOperator,
diff --git a/pyproject.toml b/pyproject.toml
index 0c2c8745b..198ccac6e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gpjax"
-version = "0.7.4"
+version = "0.8.0"
 description = "Gaussian processes in JAX."
 authors = [
     "Thomas Pinder ",
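Before the next patch reverts it, [PATCH 3/5] pairs the deep `cola` submodule imports with a new `LinearOperator` return annotation on `AbstractKernel.gram`. The following is a minimal sketch of what that annotation buys downstream code, assuming the import paths and the two-argument `solve` behave exactly as in the hunks above:

```python
# Sketch only: cola import paths are taken verbatim from this patch; the RBF
# kernel and the 1e-6 jitter mirror the calls inside ConjugatePosterior.predict.
import jax.numpy as jnp
from cola.annotations import PSD
from cola.linalg.inverse.inv import solve
from cola.ops.operators import I_like

from gpjax.kernels import RBF

x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
Kxx = RBF().gram(x)  # now annotated as returning a structured LinearOperator
Sigma = PSD(Kxx + I_like(Kxx) * 1e-6)  # jittered and tagged positive semi-definite
v = solve(Sigma, jnp.ones((50, 1)))  # structure-aware solve, no dense inverse
```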

From 2ef250b0990d86815d169782095efa5f21fed2d7 Mon Sep 17 00:00:00 2001
From: frazane
Date: Thu, 30 Nov 2023 16:41:31 +0100
Subject: [PATCH 4/5] undo dumb move

---
 gpjax/gps.py                       | 34 ++++++++++++++----------------
 gpjax/kernels/base.py              |  3 +--
 gpjax/kernels/computations/base.py |  4 ++--
 3 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/gpjax/gps.py b/gpjax/gps.py
index aa262d894..a58133859 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -28,9 +28,7 @@
     Optional,
     Union,
 )
-from cola.linalg.inverse.inv import solve
-from cola.annotations import PSD
-from cola.ops.operators import I_like
+import cola
 from cola.linalg.decompositions.decompositions import Cholesky
 import jax.numpy as jnp
 from jax.random import (
@@ -277,8 +275,8 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
-        Kxx = PSD(Kxx)
+        Kxx += cola.ops.I_like(Kxx) * self.jitter
+        Kxx = cola.PSD(Kxx)
 
         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
 
@@ -524,24 +522,24 @@ def predict(
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
+        Kxx += cola.ops.I_like(Kxx) * self.jitter
 
         # Σ = Kxx + Io²
-        Sigma = Kxx + I_like(Kxx) * obs_noise
-        Sigma = PSD(Sigma)
+        Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise
+        Sigma = cola.PSD(Sigma)
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt)
+        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
 
         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
 
         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance += cola.ops.I_like(covariance) * self.prior.jitter
+        covariance = cola.PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
@@ -603,11 +601,11 @@ def sample_approx(
         # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ~ N(0, o²)
         obs_var = self.likelihood.obs_stddev**2
         Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
+        Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
         eps = jnp.sqrt(obs_var) * normal(key, [train_data.n, num_samples])  # [N, B]
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
-        canonical_weights = solve(
+        canonical_weights = cola.solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
             Cholesky(),
@@ -686,8 +684,8 @@ def predict(
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += I_like(Kxx) * self.prior.jitter
-        Kxx = PSD(Kxx)
+        Kxx += cola.ops.I_like(Kxx) * self.prior.jitter
+        Kxx = cola.PSD(Kxx)
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs
@@ -699,7 +697,7 @@ def predict(
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
@@ -709,8 +707,8 @@ def predict(
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance += cola.ops.I_like(covariance) * self.prior.jitter
+        covariance = cola.PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py
index 4605ef4a3..ff9e7f8b6 100644
--- a/gpjax/kernels/base.py
+++ b/gpjax/kernels/base.py
@@ -29,7 +29,6 @@
     Num,
 )
 import tensorflow_probability.substrates.jax.distributions as tfd
-from cola.ops.operators import LinearOperator
 
 from gpjax.base import (
     Module,
@@ -61,7 +60,7 @@ def ndims(self):
     def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
         return self.compute_engine.cross_covariance(self, x, y)
 
-    def gram(self, x: Num[Array, "N D"]) -> LinearOperator:
+    def gram(self, x: Num[Array, "N D"]):
         return self.compute_engine.gram(self, x)
 
     def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py
index 54636fac0..ac48b8101 100644
--- a/gpjax/kernels/computations/base.py
+++ b/gpjax/kernels/computations/base.py
@@ -17,8 +17,8 @@
 from dataclasses import dataclass
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
+from cola import PSD
+from cola.ops import (
     Dense,
     Diagonal,
     LinearOperator,

From d30d783f5440eb178ed64787680303dc2d40381f Mon Sep 17 00:00:00 2001
From: frazane
Date: Thu, 30 Nov 2023 18:05:06 +0100
Subject: [PATCH 5/5] exclude type_checking blocks from coverage

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 198ccac6e..b524f9d86 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -183,6 +183,7 @@ fail_under = 50
 precision = 1
 show_missing = true
 skip_covered = true
+exclude_lines = ["if TYPE_CHECKING:"]
 
 [tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run
 branch = true
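[PATCH 5/5] closes the loop on patch 1: `if TYPE_CHECKING:` blocks never execute at runtime, so coverage would otherwise report every line inside them as permanently missed. A sketch of the pattern the new `exclude_lines` entry matches, using a hypothetical `combine` function for illustration:

```python
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:  # matched by exclude_lines = ["if TYPE_CHECKING:"], so not counted

    @overload
    def combine(a: int, b: int) -> int: ...
    @overload
    def combine(a: str, b: str) -> str: ...

def combine(a, b):
    # Runtime implementation; only these lines are measurable by coverage.
    return a + b
```

One caveat: coverage.py treats an explicit `exclude_lines` as replacing its default patterns (such as `pragma: no cover`) rather than extending them, so projects relying on the defaults typically re-add them or use `exclude_also`.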