From 7f8bff695cc29073cc4e5f3208511e53eabf821e Mon Sep 17 00:00:00 2001
From: joshuaalbert
Date: Wed, 15 May 2024 11:57:15 +0200
Subject: [PATCH] * implement #59 * bump to 2.5.0

---
 README.md                            |  6 +++-
 docs/conf.py                         |  2 +-
 jaxns/framework/__init__.py          |  1 +
 jaxns/framework/jaxify.py            | 44 ++++++++++++++++++++++++++++
 jaxns/framework/tests/test_jaxify.py | 35 ++++++++++++++++++++++
 jaxns/framework/tests/test_model.py  |  5 +---
 setup.py                             |  2 +-
 7 files changed, 88 insertions(+), 7 deletions(-)
 create mode 100644 jaxns/framework/jaxify.py
 create mode 100644 jaxns/framework/tests/test_jaxify.py

diff --git a/README.md b/README.md
index e3eebeb9..f5929a4e 100644
--- a/README.md
+++ b/README.md
@@ -363,7 +363,11 @@ is the best way to achieve speed up.
 
 # Change Log
 
-22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling not invariant to monotonic transforms of likelihod.
+15 May, 2024 -- JAXNS 2.5.0 released. Added the ability to handle non-JAX likelihoods, e.g. if you have a simulation
+framework with Python bindings, you can now use it for likelihoods in JAXNS. Small performance improvements.
+
+22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes a bug where slice sampling was not invariant to monotonic transforms of
+the likelihood.
 
 20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes, and readability improvements. Added Empirial special prior.
 
diff --git a/docs/conf.py b/docs/conf.py
index e481d7e8..5fdd9ef5 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -12,7 +12,7 @@
 project = "jaxns"
 copyright = "2022, Joshua G. Albert"
 author = "Joshua G. Albert"
-release = "2.4.13"
+release = "2.5.0"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/jaxns/framework/__init__.py b/jaxns/framework/__init__.py
index 91bc7c35..184e7793 100644
--- a/jaxns/framework/__init__.py
+++ b/jaxns/framework/__init__.py
@@ -1,4 +1,5 @@
 from jaxns.framework.model import *
 from jaxns.framework.prior import *
 from jaxns.framework.special_priors import *
+from jaxns.framework.jaxify import *
 from jaxns.framework.bases import PriorModelGen, PriorModelType
\ No newline at end of file
diff --git a/jaxns/framework/jaxify.py b/jaxns/framework/jaxify.py
new file mode 100644
index 00000000..6f223244
--- /dev/null
+++ b/jaxns/framework/jaxify.py
@@ -0,0 +1,44 @@
+import warnings
+from typing import Callable
+
+import jax
+import numpy as np
+
+from jaxns.internals.types import float_type, LikelihoodType
+
+__all__ = [
+    'jaxify_likelihood'
+]
+
+def jaxify_likelihood(log_likelihood: Callable[..., np.ndarray], vectorised: bool = False) -> LikelihoodType:
+    """
+    Wraps a non-JAX log-likelihood function so it can be used inside JAX-transformed code.
+
+    Args:
+        log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar
+            log-likelihood.
+        vectorised: if True, `log_likelihood` performs a vectorised computation over leading batch dimensions,
+            i.e. if a leading batch dimension is added to all input arguments, it returns a vector of
+            log-likelihoods with the same leading batch dimension.
+
+    Returns:
+        A JAX-compatible log-likelihood function.
+    """
+    warnings.warn(
+        "You're using a non-JAX log-likelihood function. This may be slower than a JAX log-likelihood function. "
+        "You are responsible for ensuring that the function is deterministic, "
+        "and you cannot use learnable parameters in the likelihood call."
+    )
+
+    def _casted_log_likelihood(*args) -> np.ndarray:
+        return np.asarray(log_likelihood(*args), dtype=float_type)
+
+    def _log_likelihood(*args) -> jax.Array:
+        # Define the expected shape & dtype of the output.
+        result_shape_dtype = jax.ShapeDtypeStruct(
+            shape=(),
+            dtype=float_type
+        )
+        return jax.pure_callback(_casted_log_likelihood, result_shape_dtype, *args, vectorized=vectorised)
+
+    return _log_likelihood
diff --git a/jaxns/framework/tests/test_jaxify.py b/jaxns/framework/tests/test_jaxify.py
new file mode 100644
index 00000000..b124c5f7
--- /dev/null
+++ b/jaxns/framework/tests/test_jaxify.py
@@ -0,0 +1,35 @@
+import jax
+import jax.random
+import numpy as np
+
+from jaxns import Prior, Model
+from jaxns.framework.jaxify import jaxify_likelihood
+from jaxns.framework.tests.test_model import tfpd
+
+
+def test_jaxify_likelihood():
+    def log_likelihood(x, y):
+        return np.sum(x, axis=-1) + np.sum(y, axis=-1)
+
+    wrapped_ll = jaxify_likelihood(log_likelihood)
+    np.testing.assert_allclose(wrapped_ll(np.array([1, 2]), np.array([3, 4])), 10)
+
+    vmapped_wrapped_ll = jax.vmap(jaxify_likelihood(log_likelihood, vectorised=True))
+
+    np.testing.assert_allclose(vmapped_wrapped_ll(np.array([[1, 2], [2, 2]]), np.array([[3, 4], [4, 4]])),
+                               np.array([10, 12]))
+
+
+def test_jaxify():
+    def prior_model():
+        x = yield Prior(tfpd.Uniform(), name='x').parametrised()
+        return x
+
+    @jaxify_likelihood
+    def log_likelihood(x):
+        return x
+
+    model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
+    model.sanity_check(key=jax.random.PRNGKey(0), S=10)
+    assert model.U_ndims == 0
+    assert model.num_params == 1
diff --git a/jaxns/framework/tests/test_model.py b/jaxns/framework/tests/test_model.py
index 84eb468a..22dcce23 100644
--- a/jaxns/framework/tests/test_model.py
+++ b/jaxns/framework/tests/test_model.py
@@ -138,8 +138,8 @@ def log_likelihood(obj: Obj):
     model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
     model.sanity_check(key=jax.random.PRNGKey(0), S=10)
 
-def test_empty_prior_models():
 
+def test_empty_prior_models():
     def prior_model():
         return 1.
 
@@ -148,6 +148,3 @@ def log_likelihood(x):
 
     model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
     model.sanity_check(key=jax.random.PRNGKey(0), S=10)
-
-
-
diff --git a/setup.py b/setup.py
index f329dde1..2b72fc0d 100755
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 long_description = fh.read()
 
 setup(name='jaxns',
-      version='2.4.13',
+      version='2.5.0',
       description='Nested Sampling in JAX',
       long_description=long_description,
       long_description_content_type="text/markdown",
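For reference, a minimal usage sketch of the new `jaxify_likelihood` API, adapted from `test_jaxify.py` above. The `tfpd` alias (TensorFlow Probability's JAX substrate) and the toy NumPy likelihood are illustrative assumptions, not part of the patch:

```python
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from jaxns import Model, Prior
from jaxns.framework.jaxify import jaxify_likelihood

tfpd = tfp.distributions


# A plain-NumPy log-likelihood stands in for any non-JAX simulation
# framework with Python bindings. jaxify_likelihood wraps it in
# jax.pure_callback so JAX transformations can trace through it.
@jaxify_likelihood
def log_likelihood(x):
    return -np.sum(np.square(x - 0.5))


def prior_model():
    # Standard JAXNS prior model: yield priors, return the likelihood inputs.
    x = yield Prior(tfpd.Uniform(low=0.0, high=1.0), name='x')
    return x


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)
```

With `vectorised=True`, the wrapped function can also be `jax.vmap`-ed over a leading batch dimension, as exercised in `test_jaxify_likelihood`.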