Skip to content

Commit

Permalink
Fix pymc compatibility, add more meaningful tests.
Browse files Browse the repository at this point in the history
In my tests, `oryx` does a fine job at computing inverses and ildjs, so we just compute the forward transform with pymc. Additionally, we were double transforming variables earlier.

PiperOrigin-RevId: 597840016
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 12, 2024
1 parent 9388384 commit 1f48893
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 59 deletions.
61 changes: 5 additions & 56 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from bayeux import vi
from bayeux._src import shared
import jax
import jax.numpy as jnp
import oryx

_MODULES = (mcmc, optimize, vi)
Expand Down Expand Up @@ -157,76 +156,26 @@ def from_pymc(cls, pm_model, initial_state=None):
import pymc as pm # pylint: disable=g-import-not-at-top
import pymc.sampling.jax as pm_jax # pylint: disable=g-import-not-at-top

logp = pm_jax.get_jaxified_logp(
pm.model.transform.conditioning.remove_value_transforms(pm_model))
uc_model = pm.model.transform.conditioning.remove_value_transforms(pm_model)
logp = pm_jax.get_jaxified_logp(uc_model)

rvs = pm_model.free_RVs
values = pm_model.value_vars
names = [v.name for v in rvs]

def identity(x, *_):
return x

def none(*_):
return 0.

# We have different ideas of forward and backward!
fwd_transforms = {
k.name: identity if v is None else v.backward
for k, v in pm_model.rvs_to_transforms.items()}
bwd_transforms = {
k.name: identity if v is None else v.forward
for k, v in pm_model.rvs_to_transforms.items()}

def forward_transform(pt):
return [
fwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]

def backward_transform(pt):
return [
bwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)
]

ildjs = {
k.name: none if v is None else v.log_jac_det
for k, v in pm_model.rvs_to_transforms.items()}

def ildj(pt):
return -pm.math.log(
pm.math.sum([
ildjs[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]))

fwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(forward_transform(rvs)))
bwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(backward_transform(rvs)))
ildj = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values([ildj(rvs)]))
outputs=pm_model.replace_rvs_by_values(rvs))

def logp_wrap(args):
return logp([args[k] for k in names])

def fwd_wrap(args):
ret = fwd(*[args[k] for k in names])
return dict(zip(names, ret))

def bwd_wrap(args):
ret = bwd(*[args[k] for k in names])
return dict(zip(names, ret))

def ildj_wrap(args):
return ildj(*[args[k] for k in names])[0]

test_point = {rv.name: jnp.ones(rv.shape.eval()) for rv in rvs}
test_point = uc_model.initial_point()
return cls(
log_density=logp_wrap,
test_point=test_point,
transform_fn=fwd_wrap,
inverse_transform_fn=bwd_wrap,
inverse_log_det_jacobian=ildj_wrap,
initial_state=initial_state)
12 changes: 9 additions & 3 deletions bayeux/tests/compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def numpyro_model():

bx_model = bx.Model.from_numpyro(numpyro_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_tfp():
Expand All @@ -64,7 +66,9 @@ def tfp_model():
pinned_model = tfp_model.experimental_pin(observed=treatment_effects)
bx_model = bx.Model.from_tfp(pinned_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_pymc():
Expand All @@ -85,7 +89,9 @@ def test_from_pymc():

bx_model = bx.Model.from_pymc(model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_pymc_transforms():
Expand Down

0 comments on commit 1f48893

Please sign in to comment.