Skip to content

Commit

Permalink
Merge pull request #166 from Joshuaalbert/implement-jaxify
Browse files Browse the repository at this point in the history
* implement #59
  • Loading branch information
Joshuaalbert authored May 15, 2024
2 parents 9531bb4 + 7f8bff6 commit 2ae9ca3
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 7 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,11 @@ is the best way to achieve speed up.

# Change Log

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling not invariant to monotonic transforms of likelihood.
15 May, 2024 -- JAXNS 2.5.0 released. Added ability to handle non-JAX likelihoods, e.g. if you have a simulation
framework with python bindings you can now use it for likelihoods in JAXNS. Small performance improvements.

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling was not invariant to monotonic transforms of the
likelihood.

20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes, and readability improvements. Added Empirical special prior.

Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2022, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.4.13"
release = "2.5.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
1 change: 1 addition & 0 deletions jaxns/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from jaxns.framework.model import *
from jaxns.framework.prior import *
from jaxns.framework.special_priors import *
from jaxns.framework.jaxify import *
from jaxns.framework.bases import PriorModelGen, PriorModelType
44 changes: 44 additions & 0 deletions jaxns/framework/jaxify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import warnings
from typing import Callable

import jax
import numpy as np

from jaxns.internals.types import float_type, LikelihoodType

__all__ = [
'jaxify_likelihood'
]

def jaxify_likelihood(log_likelihood: Callable[..., np.ndarray], vectorised: bool = False) -> LikelihoodType:
    """
    Wrap a plain-Python (non-JAX) log-likelihood so it can be used inside JAX traces.

    The user function is executed on the host via ``jax.pure_callback``; its scalar
    result is coerced to ``float_type`` and handed back to JAX as a 0-d array.

    Args:
        log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar
            log-likelihood.
        vectorised: if True then the `log_likelihood` performs a vectorised computation for leading batch dimensions,
            i.e. if a leading batch dimension is added to all input arguments, then it returns a vector of
            log-likelihoods with the same leading batch dimension.

    Returns:
        A JAX-compatible log-likelihood function.
    """
    warnings.warn(
        "You're using a non-JAX log-likelihood function. This may be slower than a JAX log-likelihood function. "
        "Also, you are responsible for ensuring that the function is deterministic. "
        "Also, you cannot use learnable parameters in the likelihood call."
    )

    # Host-side shim: force the user function's output to the dtype promised
    # to JAX below, otherwise pure_callback would reject the result.
    def _host_call(*inputs) -> np.ndarray:
        return np.asarray(log_likelihood(*inputs), dtype=float_type)

    # The callback contract: a single scalar of float_type. Declared once and
    # reused for every call.
    out_spec = jax.ShapeDtypeStruct(shape=(), dtype=float_type)

    def _traceable(*inputs) -> jax.Array:
        return jax.pure_callback(_host_call, out_spec, *inputs, vectorized=vectorised)

    return _traceable
35 changes: 35 additions & 0 deletions jaxns/framework/tests/test_jaxify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import jax
import jax.random
import numpy as np

from jaxns import Prior, Model
from jaxns.framework.jaxify import jaxify_likelihood
from jaxns.framework.tests.test_model import tfpd


def test_jaxify_likelihood():
    # A pure-numpy likelihood: reduce the trailing axis of each argument.
    def log_likelihood(x, y):
        return np.sum(x, axis=-1) + np.sum(y, axis=-1)

    # Scalar call through the wrapper must agree with the direct numpy value.
    jax_ll = jaxify_likelihood(log_likelihood)
    np.testing.assert_allclose(jax_ll(np.array([1, 2]), np.array([3, 4])), 10)

    # With vectorised=True, vmap can forward a leading batch dimension
    # straight to the host function.
    batched_ll = jax.vmap(jaxify_likelihood(log_likelihood, vectorised=True))
    np.testing.assert_allclose(
        batched_ll(np.array([[1, 2], [2, 2]]), np.array([[3, 4], [4, 4]])),
        np.array([10, 12])
    )


def test_jaxify():
    # One parametrised prior: contributes no Bayesian dimensions (U_ndims)
    # but exactly one learnable parameter.
    def _prior_model():
        x = yield Prior(tfpd.Uniform(), name='x').parametrised()
        return x

    # Identity likelihood, wrapped so the (non-JAX) callable is usable by JAXNS.
    @jaxify_likelihood
    def _ll(x):
        return x

    model = Model(prior_model=_prior_model, log_likelihood=_ll)
    model.sanity_check(key=jax.random.PRNGKey(0), S=10)
    assert model.U_ndims == 0
    assert model.num_params == 1
5 changes: 1 addition & 4 deletions jaxns/framework/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def log_likelihood(obj: Obj):
model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)

def test_empty_prior_models():

def test_empty_prior_models():
def prior_model():
return 1.

Expand All @@ -148,6 +148,3 @@ def log_likelihood(x):

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)



2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
long_description = fh.read()

setup(name='jaxns',
version='2.4.13',
version='2.5.0',
description='Nested Sampling in JAX',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 2ae9ca3

Please sign in to comment.