Skip to content

Commit

Permalink
Merge pull request #425 from JaxGaussianProcesses/generic_types
Browse files Browse the repository at this point in the history
Add generic typing information for gps objects
  • Loading branch information
thomaspinder authored Nov 30, 2023
2 parents 55ecaac + d30d783 commit 6955957
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 44 deletions.
112 changes: 70 additions & 42 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing import overload
from typing import (
TYPE_CHECKING,
Generic,
TypeVar,
overload,
)

from beartype.typing import (
Any,
Expand Down Expand Up @@ -47,7 +52,7 @@
from gpjax.likelihoods import (
AbstractLikelihood,
Gaussian,
NonGaussianLikelihood,
NonGaussian,
)
from gpjax.lower_cholesky import lower_cholesky
from gpjax.mean_functions import AbstractMeanFunction
Expand All @@ -57,13 +62,19 @@
KeyArray,
)

# Type variables parameterising the generic GP classes in this module.
# Each is bounded by the corresponding abstract base class (or, for the
# likelihood variables, by the concrete `Gaussian` / `NonGaussian` types)
# so that e.g. `Prior[MeanFunction, Kernel]` carries the concrete kernel
# and mean-function types through the `__mul__`/`__rmul__` overloads and
# into the resulting posterior type.
Kernel = TypeVar("Kernel", bound=AbstractKernel)
MeanFunction = TypeVar("MeanFunction", bound=AbstractMeanFunction)
Likelihood = TypeVar("Likelihood", bound=AbstractLikelihood)
NonGaussianLikelihood = TypeVar("NonGaussianLikelihood", bound=NonGaussian)
GaussianLikelihood = TypeVar("GaussianLikelihood", bound=Gaussian)


@dataclass
class AbstractPrior(Module):
class AbstractPrior(Module, Generic[MeanFunction, Kernel]):
r"""Abstract Gaussian process prior."""

kernel: AbstractKernel
mean_function: AbstractMeanFunction
kernel: Kernel
mean_function: MeanFunction
jitter: float = static_field(1e-6)

def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
Expand Down Expand Up @@ -113,7 +124,7 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
# GP Priors
#######################
@dataclass
class Prior(AbstractPrior):
class Prior(AbstractPrior[MeanFunction, Kernel]):
r"""A Gaussian process prior object.
The GP is parameterised by a
Expand All @@ -136,18 +147,27 @@ class Prior(AbstractPrior):
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
```
"""

# @overload
# def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...

# @overload
# def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...

# @overload
# def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...
if TYPE_CHECKING:

@overload
def __mul__(
self, other: GaussianLikelihood
) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
...

@overload
def __mul__(
self, other: NonGaussianLikelihood
) -> (
"NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
):
...

@overload
def __mul__(
self, other: Likelihood
) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
...

def __mul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -183,17 +203,27 @@ def __mul__(self, other):
"""
return construct_posterior(prior=self, likelihood=other)

# @overload
# def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...
if TYPE_CHECKING:

@overload
def __rmul__(
self, other: GaussianLikelihood
) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
...

# @overload
# def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...
@overload
def __rmul__(
self, other: NonGaussianLikelihood
) -> (
"NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
):
...

# @overload
# def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...
@overload
def __rmul__(
self, other: Likelihood
) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
...

def __rmul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -324,19 +354,22 @@ def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
return sample_fn


PriorType = TypeVar("PriorType", bound=AbstractPrior)


#######################
# GP Posteriors
#######################
@dataclass
class AbstractPosterior(Module):
class AbstractPosterior(Module, Generic[PriorType, Likelihood]):
r"""Abstract Gaussian process posterior.
The base GP posterior object conditioned on an observed dataset. All
posterior objects should inherit from this class.
"""

prior: AbstractPrior
likelihood: AbstractLikelihood
prior: AbstractPrior[MeanFunction, Kernel]
likelihood: Likelihood
jitter: float = static_field(1e-6)

def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
Expand Down Expand Up @@ -381,7 +414,7 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution:


@dataclass
class ConjugatePosterior(AbstractPosterior):
class ConjugatePosterior(AbstractPosterior[PriorType, GaussianLikelihood]):
r"""A Conjugate Gaussian process posterior object.
A Gaussian process posterior distribution when the constituent likelihood
Expand Down Expand Up @@ -600,7 +633,7 @@ def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:


@dataclass
class NonConjugatePosterior(AbstractPosterior):
class NonConjugatePosterior(AbstractPosterior[PriorType, NonGaussianLikelihood]):
r"""A non-conjugate Gaussian process posterior object.
A Gaussian process posterior object for models where the likelihood is
Expand Down Expand Up @@ -685,22 +718,17 @@ def predict(
#######################


@overload
def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior:
...


@overload
def construct_posterior(
prior: Prior, likelihood: NonGaussianLikelihood
) -> NonConjugatePosterior:
prior: PriorType, likelihood: GaussianLikelihood
) -> ConjugatePosterior[PriorType, GaussianLikelihood]:
...


@overload
def construct_posterior(
prior: Prior, likelihood: AbstractLikelihood
) -> AbstractPosterior:
prior: PriorType, likelihood: NonGaussianLikelihood
) -> NonConjugatePosterior[PriorType, NonGaussianLikelihood]:
...


Expand Down
4 changes: 2 additions & 2 deletions gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter


NonGaussianLikelihood = Union[Poisson, Bernoulli]
NonGaussian = Union[Poisson, Bernoulli]

__all__ = [
"AbstractLikelihood",
"NonGaussianLikelihood",
"NonGaussian",
"Gaussian",
"Bernoulli",
"Poisson",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ fail_under = 50
precision = 1
show_missing = true
skip_covered = true
exclude_lines = ["if TYPE_CHECKING:"]

[tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run
branch = true
Expand Down

0 comments on commit 6955957

Please sign in to comment.