From 66753aa72b0ec14c0e155847a4e0f12f08ce31ad Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Fri, 12 Jan 2024 12:38:43 -0800 Subject: [PATCH] Fix pymc compatibility, add more meaningful tests. In my tests, `oryx` does a fine job at computing inverses and ildjs, so we just compute the forward transform with pymc. Additionally, we were double transforming variables earlier. PiperOrigin-RevId: 597922177 --- bayeux/_src/bayeux.py | 71 ++++++++++++----------------------- bayeux/tests/compat_test.py | 12 ++++-- bayeux/tests/optimize_test.py | 1 + 3 files changed, 34 insertions(+), 50 deletions(-) diff --git a/bayeux/_src/bayeux.py b/bayeux/_src/bayeux.py index c579236..3095a05 100644 --- a/bayeux/_src/bayeux.py +++ b/bayeux/_src/bayeux.py @@ -21,7 +21,6 @@ from bayeux import vi from bayeux._src import shared import jax -import jax.numpy as jnp import oryx _MODULES = (mcmc, optimize, vi) @@ -157,57 +156,39 @@ def from_pymc(cls, pm_model, initial_state=None): import pymc as pm # pylint: disable=g-import-not-at-top import pymc.sampling.jax as pm_jax # pylint: disable=g-import-not-at-top - logp = pm_jax.get_jaxified_logp( - pm.model.transform.conditioning.remove_value_transforms(pm_model)) + class Inverse(pm.logprob.transforms.Transform): + def __init__(self, transform): + self._transform = transform - rvs = pm_model.free_RVs - values = pm_model.value_vars - names = [v.name for v in rvs] - - def identity(x, *_): - return x - - def none(*_): - return 0. - - # We have different ideas of forward and backward! - fwd_transforms = { - k.name: identity if v is None else v.backward - for k, v in pm_model.rvs_to_transforms.items()} - bwd_transforms = { - k.name: identity if v is None else v.forward - for k, v in pm_model.rvs_to_transforms.items()} + def forward(self, value, *inputs): + """Apply the transformation.""" + return self._transform.backward(value, *inputs) - def forward_transform(pt): - return [ - fwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt)] + def backward(self, value, *inputs): + return self._transform.forward(value, *inputs) - def backward_transform(pt): - return [ - bwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt) - ] + uc_model = pm.model.transform.conditioning.remove_value_transforms(pm_model) + logp = pm_jax.get_jaxified_logp(uc_model) - ildjs = { - k.name: none if v is None else v.log_jac_det + rvs_to_inverse = { + k: None if v is None else Inverse(v) for k, v in pm_model.rvs_to_transforms.items()} - - def ildj(pt): - return -pm.math.log( - pm.math.sum([ - ildjs[k](v, *([] if v.owner is None else v.owner.inputs)) - for k, v in zip(names, pt)])) + rvs = pm_model.free_RVs + inv_rvs = pm.logprob.utils.replace_rvs_by_values( + rvs, + rvs_to_values=pm_model.rvs_to_values, + rvs_to_transforms=rvs_to_inverse) + values = pm_model.value_vars + names = [v.name for v in rvs] fwd = pm_jax.get_jaxified_graph( inputs=values, - outputs=pm_model.replace_rvs_by_values(forward_transform(rvs))) + outputs=pm_model.replace_rvs_by_values(rvs)) + bwd = pm_jax.get_jaxified_graph( inputs=values, - outputs=pm_model.replace_rvs_by_values(backward_transform(rvs))) - ildj = pm_jax.get_jaxified_graph( - inputs=values, - outputs=pm_model.replace_rvs_by_values([ildj(rvs)])) + outputs=pm_model.replace_rvs_by_values(inv_rvs)) + def logp_wrap(args): return logp([args[k] for k in names]) @@ -219,14 +200,10 @@ def bwd_wrap(args): ret = bwd(*[args[k] for k in names]) return dict(zip(names, ret)) - def ildj_wrap(args): - return ildj(*[args[k] for k in names])[0] - - test_point = {rv.name: jnp.ones(rv.shape.eval()) for rv in rvs} + test_point = uc_model.initial_point() return cls( log_density=logp_wrap, test_point=test_point, transform_fn=fwd_wrap, inverse_transform_fn=bwd_wrap, - inverse_log_det_jacobian=ildj_wrap, initial_state=initial_state) diff --git a/bayeux/tests/compat_test.py b/bayeux/tests/compat_test.py index cb3b687..5846808 100644 --- a/bayeux/tests/compat_test.py +++ b/bayeux/tests/compat_test.py @@ -42,7 +42,9 @@ def numpyro_model(): bx_model = bx.Model.from_numpyro(numpyro_model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_tfp(): @@ -64,7 +66,9 @@ def tfp_model(): pinned_model = tfp_model.experimental_pin(observed=treatment_effects) bx_model = bx.Model.from_tfp(pinned_model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_pymc(): @@ -85,7 +89,9 @@ def test_from_pymc(): bx_model = bx.Model.from_pymc(model) ret = bx_model.optimize.optax_adam(seed=jax.random.key(0)) - assert ret is not None + log_probs = jax.vmap(bx_model.log_density)(ret.params) + expected = np.full(8, -42.56270289) + np.testing.assert_allclose(log_probs, expected, atol=1e-3) def test_from_pymc_transforms(): diff --git a/bayeux/tests/optimize_test.py b/bayeux/tests/optimize_test.py index 7ca34aa..a967279 100644 --- a/bayeux/tests/optimize_test.py +++ b/bayeux/tests/optimize_test.py @@ -88,6 +88,7 @@ def test_optimizers(method, linear_model): # pylint: disable=redefined-outer-na "optax_adafactor", "optax_adagrad", "optax_sm3", + "optimistix_bfgs", "optimistix_chord", "optimistix_nelder_mead"}: np.testing.assert_allclose(expected, params.w, atol=atol)