Skip to content

Commit

Permalink
Drop cola's special plum, and use single dispatch for citations.
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Aug 29, 2023
1 parent d3c2b56 commit 1c3f202
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 27 deletions.
44 changes: 23 additions & 21 deletions gpjax/citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
dataclass,
fields,
)
from functools import singledispatch

from beartype.typing import (
Dict,
Union,
)
from jaxlib.xla_extension import PjitFunction
from plum import dispatch

from gpjax.kernels import (
RFF,
Expand All @@ -26,8 +26,6 @@
NonConjugateMLL,
)

MaternKernels = Union[Matern12, Matern32, Matern52]
MLLs = Union[ConjugateMLL, NonConjugateMLL, LogPosteriorDensity]
CitationType = Union[str, Dict[str, str]]


Expand Down Expand Up @@ -89,24 +87,26 @@ class BookCitation(AbstractCitation):
####################
# Default citation
####################
@dispatch
def cite(tree) -> NullCitation:
@singledispatch
def cite(tree) -> AbstractCitation:
return NullCitation()


####################
# Default citation
####################
@dispatch
def cite(tree: PjitFunction) -> JittedFnCitation:
@cite.register(PjitFunction)
def _(tree):
return JittedFnCitation()


####################
# Kernel citations
####################
@dispatch
def cite(tree: MaternKernels) -> PhDThesisCitation:
@cite.register(Matern12)
@cite.register(Matern32)
@cite.register(Matern52)
def _(tree) -> PhDThesisCitation:
citation = PhDThesisCitation(
citation_key="matern1960SpatialV",
authors="Bertil Matérn",
Expand All @@ -121,8 +121,8 @@ def cite(tree: MaternKernels) -> PhDThesisCitation:
return citation


@dispatch
def cite(tree: ArcCosine) -> PaperCitation:
@cite.register(ArcCosine)
def _(_) -> PaperCitation:
return PaperCitation(
citation_key="cho2009kernel",
authors="Cho, Youngmin and Saul, Lawrence",
Expand All @@ -132,8 +132,8 @@ def cite(tree: ArcCosine) -> PaperCitation:
)


@dispatch
def cite(tree: GraphKernel) -> PaperCitation:
@cite.register(GraphKernel)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="borovitskiy2021matern",
title="Matérn Gaussian Processes on Graphs",
Expand All @@ -146,8 +146,8 @@ def cite(tree: GraphKernel) -> PaperCitation:
)


@dispatch
def cite(tree: RFF) -> PaperCitation:
@cite.register(RFF)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="rahimi2007random",
authors="Rahimi, Ali and Recht, Benjamin",
Expand All @@ -161,8 +161,10 @@ def cite(tree: RFF) -> PaperCitation:
####################
# Objective citations
####################
@dispatch
def cite(tree: MLLs) -> BookCitation:
@cite.register(ConjugateMLL)
@cite.register(NonConjugateMLL)
@cite.register(LogPosteriorDensity)
def _(tree) -> BookCitation:
return BookCitation(
citation_key="rasmussen2006gaussian",
title="Gaussian Processes for Machine Learning",
Expand All @@ -173,8 +175,8 @@ def cite(tree: MLLs) -> BookCitation:
)


@dispatch
def cite(tree: CollapsedELBO) -> PaperCitation:
@cite.register(CollapsedELBO)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="titsias2009variational",
title="Variational learning of inducing variables in sparse Gaussian processes",
Expand All @@ -184,8 +186,8 @@ def cite(tree: CollapsedELBO) -> PaperCitation:
)


@dispatch
def cite(tree: ELBO) -> PaperCitation:
@cite.register(ELBO)
def _(tree) -> PaperCitation:
return PaperCitation(
citation_key="hensman2013gaussian",
title="Gaussian Processes for Big Data",
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jax = ">=0.4.10"
jaxlib = ">=0.4.10"
orbax-checkpoint = ">=0.2.3"
cola-ml = "^0.0.1"
cola-plum-dispatch = "^0.1.1"

[tool.poetry.group.test.dependencies]
pytest = "^7.2.2"
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jaxtyping import install_import_hook
# from jaxtyping import install_import_hook

# import gpjax within import hook to apply beartype everywhere, before running tests
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax # noqa: F401
# # import gpjax within import hook to apply beartype everywhere, before running tests
# with install_import_hook("gpjax", "beartype.beartype"):
# import gpjax # noqa: F401

0 comments on commit 1c3f202

Please sign in to comment.