Skip to content

Commit

Permalink
Fix pymc compatibility, add more meaningful tests.
Browse files Browse the repository at this point in the history
In my tests, `oryx` does a fine job at computing inverses and ildjs, so we just compute the forward transform with pymc. Additionally, we were double transforming variables earlier.

PiperOrigin-RevId: 597922177
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 12, 2024
1 parent 9388384 commit 66753aa
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 50 deletions.
71 changes: 24 additions & 47 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from bayeux import vi
from bayeux._src import shared
import jax
import jax.numpy as jnp
import oryx

_MODULES = (mcmc, optimize, vi)
Expand Down Expand Up @@ -157,57 +156,39 @@ def from_pymc(cls, pm_model, initial_state=None):
import pymc as pm # pylint: disable=g-import-not-at-top
import pymc.sampling.jax as pm_jax # pylint: disable=g-import-not-at-top

logp = pm_jax.get_jaxified_logp(
pm.model.transform.conditioning.remove_value_transforms(pm_model))
class Inverse(pm.logprob.transforms.Transform):
def __init__(self, transform):
self._transform = transform

rvs = pm_model.free_RVs
values = pm_model.value_vars
names = [v.name for v in rvs]

def identity(x, *_):
return x

def none(*_):
return 0.

# We have different ideas of forward and backward!
fwd_transforms = {
k.name: identity if v is None else v.backward
for k, v in pm_model.rvs_to_transforms.items()}
bwd_transforms = {
k.name: identity if v is None else v.forward
for k, v in pm_model.rvs_to_transforms.items()}
def forward(self, value, *inputs):
"""Apply the transformation."""
return self._transform.backward(value, *inputs)

def forward_transform(pt):
return [
fwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]
def backward(self, value, *inputs):
return self._transform.forward(value, *inputs)

def backward_transform(pt):
return [
bwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)
]
uc_model = pm.model.transform.conditioning.remove_value_transforms(pm_model)
logp = pm_jax.get_jaxified_logp(uc_model)

ildjs = {
k.name: none if v is None else v.log_jac_det
rvs_to_inverse = {
k: None if v is None else Inverse(v)
for k, v in pm_model.rvs_to_transforms.items()}

def ildj(pt):
return -pm.math.log(
pm.math.sum([
ildjs[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]))
rvs = pm_model.free_RVs
inv_rvs = pm.logprob.utils.replace_rvs_by_values(
rvs,
rvs_to_values=pm_model.rvs_to_values,
rvs_to_transforms=rvs_to_inverse)
values = pm_model.value_vars
names = [v.name for v in rvs]

fwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(forward_transform(rvs)))
outputs=pm_model.replace_rvs_by_values(rvs))

bwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(backward_transform(rvs)))
ildj = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values([ildj(rvs)]))
outputs=pm_model.replace_rvs_by_values(inv_rvs))

def logp_wrap(args):
return logp([args[k] for k in names])

Expand All @@ -219,14 +200,10 @@ def bwd_wrap(args):
ret = bwd(*[args[k] for k in names])
return dict(zip(names, ret))

def ildj_wrap(args):
return ildj(*[args[k] for k in names])[0]

test_point = {rv.name: jnp.ones(rv.shape.eval()) for rv in rvs}
test_point = uc_model.initial_point()
return cls(
log_density=logp_wrap,
test_point=test_point,
transform_fn=fwd_wrap,
inverse_transform_fn=bwd_wrap,
inverse_log_det_jacobian=ildj_wrap,
initial_state=initial_state)
12 changes: 9 additions & 3 deletions bayeux/tests/compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def numpyro_model():

bx_model = bx.Model.from_numpyro(numpyro_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_tfp():
Expand All @@ -64,7 +66,9 @@ def tfp_model():
pinned_model = tfp_model.experimental_pin(observed=treatment_effects)
bx_model = bx.Model.from_tfp(pinned_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_pymc():
Expand All @@ -85,7 +89,9 @@ def test_from_pymc():

bx_model = bx.Model.from_pymc(model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
log_probs = jax.vmap(bx_model.log_density)(ret.params)
expected = np.full(8, -42.56270289)
np.testing.assert_allclose(log_probs, expected, atol=1e-3)


def test_from_pymc_transforms():
Expand Down
1 change: 1 addition & 0 deletions bayeux/tests/optimize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_optimizers(method, linear_model): # pylint: disable=redefined-outer-na
"optax_adafactor",
"optax_adagrad",
"optax_sm3",
"optimistix_bfgs",
"optimistix_chord",
"optimistix_nelder_mead"}:
np.testing.assert_allclose(expected, params.w, atol=atol)
Expand Down

0 comments on commit 66753aa

Please sign in to comment.