diff --git a/gpjax/citation.py b/gpjax/citation.py index 74a7a7391..52fce6911 100644 --- a/gpjax/citation.py +++ b/gpjax/citation.py @@ -2,13 +2,13 @@ dataclass, fields, ) +from functools import singledispatch from beartype.typing import ( Dict, Union, ) from jaxlib.xla_extension import PjitFunction -from plum import dispatch from gpjax.kernels import ( RFF, @@ -26,8 +26,6 @@ NonConjugateMLL, ) -MaternKernels = Union[Matern12, Matern32, Matern52] -MLLs = Union[ConjugateMLL, NonConjugateMLL, LogPosteriorDensity] CitationType = Union[str, Dict[str, str]] @@ -89,24 +87,26 @@ class BookCitation(AbstractCitation): #################### # Default citation #################### -@dispatch -def cite(tree) -> NullCitation: +@singledispatch +def cite(tree) -> AbstractCitation: return NullCitation() #################### # Default citation #################### -@dispatch -def cite(tree: PjitFunction) -> JittedFnCitation: +@cite.register(PjitFunction) +def _(tree): return JittedFnCitation() #################### # Kernel citations #################### -@dispatch -def cite(tree: MaternKernels) -> PhDThesisCitation: +@cite.register(Matern12) +@cite.register(Matern32) +@cite.register(Matern52) +def _(tree) -> PhDThesisCitation: citation = PhDThesisCitation( citation_key="matern1960SpatialV", authors="Bertil Matérn", @@ -121,8 +121,8 @@ def cite(tree: MaternKernels) -> PhDThesisCitation: return citation -@dispatch -def cite(tree: ArcCosine) -> PaperCitation: +@cite.register(ArcCosine) +def _(_) -> PaperCitation: return PaperCitation( citation_key="cho2009kernel", authors="Cho, Youngmin and Saul, Lawrence", @@ -132,8 +132,8 @@ def cite(tree: ArcCosine) -> PaperCitation: ) -@dispatch -def cite(tree: GraphKernel) -> PaperCitation: +@cite.register(GraphKernel) +def _(tree) -> PaperCitation: return PaperCitation( citation_key="borovitskiy2021matern", title="Matérn Gaussian Processes on Graphs", @@ -146,8 +146,8 @@ def cite(tree: GraphKernel) -> PaperCitation: ) -@dispatch -def cite(tree: RFF) -> PaperCitation: +@cite.register(RFF) +def _(tree) -> PaperCitation: return PaperCitation( citation_key="rahimi2007random", authors="Rahimi, Ali and Recht, Benjamin", @@ -161,8 +161,10 @@ def cite(tree: RFF) -> PaperCitation: #################### # Objective citations #################### -@dispatch -def cite(tree: MLLs) -> BookCitation: +@cite.register(ConjugateMLL) +@cite.register(NonConjugateMLL) +@cite.register(LogPosteriorDensity) +def _(tree) -> BookCitation: return BookCitation( citation_key="rasmussen2006gaussian", title="Gaussian Processes for Machine Learning", @@ -173,8 +175,8 @@ def cite(tree: MLLs) -> BookCitation: ) -@dispatch -def cite(tree: CollapsedELBO) -> PaperCitation: +@cite.register(CollapsedELBO) +def _(tree) -> PaperCitation: return PaperCitation( citation_key="titsias2009variational", title="Variational learning of inducing variables in sparse Gaussian processes", @@ -184,8 +186,8 @@ def cite(tree: CollapsedELBO) -> PaperCitation: ) -@dispatch -def cite(tree: ELBO) -> PaperCitation: +@cite.register(ELBO) +def _(tree) -> PaperCitation: return PaperCitation( citation_key="hensman2013gaussian", title="Gaussian Processes for Big Data", diff --git a/poetry.lock b/poetry.lock index c8c8daf90..b386e35c6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5813,4 +5813,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "9b0aa2fdb17f57a3836ebe54d26eb4b1f5aa376a50688db810b3498a514b228a" +content-hash = "95048d017009d7fafa176580db39b5b348e1dded982777edc6771ceb8e0c6135" diff --git a/pyproject.toml b/pyproject.toml index 80f057f84..fe27586de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ jax = ">=0.4.10" jaxlib = ">=0.4.10" orbax-checkpoint = ">=0.2.3" cola-ml = "^0.0.1" -cola-plum-dispatch = "^0.1.1" [tool.poetry.group.test.dependencies] pytest = "^7.2.2" diff --git a/tests/conftest.py b/tests/conftest.py index e12a1f72d..451074f1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -from jaxtyping import install_import_hook +# from jaxtyping import install_import_hook -# import gpjax within import hook to apply beartype everywhere, before running tests -with install_import_hook("gpjax", "beartype.beartype"): - import gpjax # noqa: F401 +# # import gpjax within import hook to apply beartype everywhere, before running tests +# with install_import_hook("gpjax", "beartype.beartype"): +# import gpjax # noqa: F401