Skip to content

Commit

Permalink
Add a .from_pymc adapter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597570748
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 11, 2024
1 parent 64ef791 commit 59a9877
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 0 deletions.
79 changes: 79 additions & 0 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from bayeux import vi
from bayeux._src import shared
import jax
import jax.numpy as jnp
import oryx

_MODULES = (mcmc, optimize, vi)
Expand Down Expand Up @@ -151,3 +152,81 @@ def transform_fn(x):
transform_fn=transform_fn,
initial_state=initial_state)

@classmethod
def from_pymc(cls, pm_model, initial_state=None):
import pymc as pm # pylint: disable=g-import-not-at-top
import pymc.sampling.jax as pm_jax # pylint: disable=g-import-not-at-top

logp = pm_jax.get_jaxified_logp(
pm.model.transform.conditioning.remove_value_transforms(pm_model))

rvs = pm_model.free_RVs
values = pm_model.value_vars
names = [v.name for v in rvs]

def identity(x, *_):
return x

def none(*_):
return 0.

# We have different ideas of forward and backward!
fwd_transforms = {
k.name: identity if v is None else v.backward
for k, v in pm_model.rvs_to_transforms.items()}
bwd_transforms = {
k.name: identity if v is None else v.forward
for k, v in pm_model.rvs_to_transforms.items()}

def forward_transform(pt):
return [
fwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]

def backward_transform(pt):
return [
bwd_transforms[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)
]

ildjs = {
k.name: none if v is None else v.log_jac_det
for k, v in pm_model.rvs_to_transforms.items()}

def ildj(pt):
return -pm.math.log(
pm.math.sum([
ildjs[k](v, *([] if v.owner is None else v.owner.inputs))
for k, v in zip(names, pt)]))

fwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(forward_transform(rvs)))
bwd = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values(backward_transform(rvs)))
ildj = pm_jax.get_jaxified_graph(
inputs=values,
outputs=pm_model.replace_rvs_by_values([ildj(rvs)]))
def logp_wrap(args):
return logp([args[k] for k in names])

def fwd_wrap(args):
ret = fwd(*[args[k] for k in names])
return dict(zip(names, ret))

def bwd_wrap(args):
ret = bwd(*[args[k] for k in names])
return dict(zip(names, ret))

def ildj_wrap(args):
return ildj(*[args[k] for k in names])[0]

test_point = {rv.name: jnp.ones(rv.shape.eval()) for rv in rvs}
return cls(
log_density=logp_wrap,
test_point=test_point,
transform_fn=fwd_wrap,
inverse_transform_fn=bwd_wrap,
inverse_log_det_jacobian=ildj_wrap,
initial_state=initial_state)
35 changes: 35 additions & 0 deletions bayeux/tests/compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,38 @@ def tfp_model():
bx_model = bx.Model.from_tfp(pinned_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None


def test_from_pymc():
import pymc as pm # pylint: disable=g-import-not-at-top

treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
[15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

with pm.Model() as model:
avg_effect = pm.Normal('avg_effect', 0., 10.)
avg_stddev = pm.HalfNormal('avg_stddev', 10.)
school_effects = pm.Normal('school_effects', shape=8)
pm.Normal('observed',
avg_effect + avg_stddev * school_effects,
treatment_stddevs,
observed=treatment_effects)

bx_model = bx.Model.from_pymc(model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None


def test_from_pymc_transforms():
import pymc as pm # pylint: disable=g-import-not-at-top

with pm.Model() as model:
pm.Normal('y')
lo = pm.Uniform('lo', lower=-1., upper=0.)
hi = pm.Uniform('hi', lower=0., upper=1.)
pm.Uniform('x', lower=lo, upper=hi)

bx_model = bx.Model.from_pymc(model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
2 changes: 2 additions & 0 deletions bayeux/tests/optimize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def test_optimizers(method, linear_model): # pylint: disable=redefined-outer-na

if method not in {
"optax_adafactor",
"optax_adagrad",
"optax_sm3",
"optimistix_chord",
"optimistix_nelder_mead"}:
np.testing.assert_allclose(expected, params.w, atol=atol)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dev = [
"pytest-xdist",
"pylint>=2.6.0",
"pyink",
"pymc",
]


Expand Down

0 comments on commit 59a9877

Please sign in to comment.