diff --git a/bayeux/_src/bayeux.py b/bayeux/_src/bayeux.py index c579236..6d8dc78 100644 --- a/bayeux/_src/bayeux.py +++ b/bayeux/_src/bayeux.py @@ -21,7 +21,6 @@ from bayeux import vi from bayeux._src import shared import jax -import jax.numpy as jnp import oryx _MODULES = (mcmc, optimize, vi) @@ -157,76 +156,26 @@ def from_pymc(cls, pm_model, initial_state=None): import pymc as pm # pylint: disable=g-import-not-at-top import pymc.sampling.jax as pm_jax # pylint: disable=g-import-not-at-top - logp = pm_jax.get_jaxified_logp( - pm.model.transform.conditioning.remove_value_transforms(pm_model)) + uc_model = pm.model.transform.conditioning.remove_value_transforms(pm_model) + logp = pm_jax.get_jaxified_logp(uc_model) rvs = pm_model.free_RVs values = pm_model.value_vars names = [v.name for v in rvs] - def identity(x, *_): - return x - - def none(*_): - return 0. - - # We have different ideas of forward and backward! - fwd_transforms = { - k.name: identity if v is None else v.backward - for k, v in pm_model.rvs_to_transforms.items()} - bwd_transforms = { - k.name: identity if v is None else v.forward - for k, v in pm_model.rvs_to_transforms.items()} - - def forward_transform(pt): - return [ - fwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt)] - - def backward_transform(pt): - return [ - bwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt) - ] - - ildjs = { - k.name: none if v is None else v.log_jac_det - for k, v in pm_model.rvs_to_transforms.items()} - - def ildj(pt): - return -pm.math.log( - pm.math.sum([ - ildjs[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt)])) - fwd = pm_jax.get_jaxified_graph( inputs=values, - outputs=pm_model.replace_rvs_by_values(forward_transform(rvs))) - bwd = pm_jax.get_jaxified_graph( - inputs=values, - outputs=pm_model.replace_rvs_by_values(backward_transform(rvs))) - ildj = pm_jax.get_jaxified_graph( - inputs=values, - outputs=pm_model.replace_rvs_by_values([ildj(rvs)])) + outputs=pm_model.replace_rvs_by_values(rvs)) + def logp_wrap(args): return logp([args[k] for k in names]) def fwd_wrap(args): ret = fwd(*[args[k] for k in names]) return dict(zip(names, ret)) - - def bwd_wrap(args): - ret = bwd(*[args[k] for k in names]) - return dict(zip(names, ret)) - - def ildj_wrap(args): - return ildj(*[args[k] for k in names])[0] - - test_point = {rv.name: jnp.ones(rv.shape.eval()) for rv in rvs} + test_point = uc_model.initial_point() return cls( log_density=logp_wrap, test_point=test_point, transform_fn=fwd_wrap, - inverse_transform_fn=bwd_wrap, - inverse_log_det_jacobian=ildj_wrap, initial_state=initial_state) diff --git a/bayeux/tests/compat_test.py b/bayeux/tests/compat_test.py index cb3b687..5846808 100644 --- a/bayeux/tests/compat_test.py +++ b/bayeux/tests/compat_test.py @@ -42,7 +42,9 @@ def numpyro_model(): bx_model = bx.Model.from_numpyro(numpyro_model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_tfp(): @@ -64,7 +66,9 @@ def tfp_model(): pinned_model = tfp_model.experimental_pin(observed=treatment_effects) bx_model = bx.Model.from_tfp(pinned_model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_pymc(): @@ -85,7 +89,9 @@ def test_from_pymc(): bx_model = bx.Model.from_pymc(model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_pymc_transforms():