Skip to content

Commit

Permalink
Add nutpie sampler to bayeux.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657599766
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jul 30, 2024
1 parent 718368d commit e46e7f4
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 1 deletion.
2 changes: 1 addition & 1 deletion bayeux/_src/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def debug_no_ildj(

def _get_num_chains(default_kwargs):
for v in default_kwargs.values():
for key in ("num_chains", "num_particles", "batch_size"):
for key in ("num_chains", "num_particles", "batch_size", "chains"):
if key in v:
return v[key]
raise KeyError("No `num_chains` in default kwargs!")
Expand Down
135 changes: 135 additions & 0 deletions bayeux/_src/mcmc/nutpie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Nutpie specific code."""

import arviz as az
from bayeux._src import shared
import jax
import numpy as np

import nutpie
from nutpie.compiled_pyfunc import from_pyfunc


class _NutpieSampler(shared.Base):
"""Base class for nutpie sampler."""
name: str = "nutpie"

def _get_aux(self):
flat, unflatten = jax.flatten_util.ravel_pytree(self.test_point)

def flatten(pytree):
return jax.flatten_util.ravel_pytree(pytree)[0]

def make_logp_fn():
constrained_log_density = self.constrained_log_density()
def log_density(x):
return constrained_log_density(unflatten(x)).squeeze()
log_grad = jax.jit(jax.value_and_grad(log_density))
def wrapper(x):
val, grad = log_grad(x)
return val, np.array(grad, dtype=np.float64)
return wrapper
return make_logp_fn, flatten, unflatten, flat.shape[0]

def get_kwargs(self, **kwargs):
make_logp_fn, flatten, unflatten, ndim = self._get_aux()

def make_expand_fn(*args, **kwargs):
del args
del kwargs
return lambda x: {"x": np.asarray(x, dtype="float64")}

from_pyfunc_kwargs = {
"ndim": ndim,
"make_logp_fn": make_logp_fn,
"make_expand_fn": make_expand_fn,
"expanded_shapes": [(ndim,)],
"expanded_names": ["x"],
"expanded_dtypes": [np.float64],
}
from_pyfunc_kwargs = {
k: kwargs.get(k, v) for k, v in from_pyfunc_kwargs.items()}

kwargs_with_defaults = {
"draws": 1_000,
"chains": 8,
} | kwargs
sample_kwargs, _ = shared.get_default_signature(nutpie.sample)
sample_kwargs.update({k: kwargs_with_defaults[k] for k in sample_kwargs if
k in kwargs_with_defaults})
if "cores" not in kwargs:
sample_kwargs["cores"] = sample_kwargs["chains"]
extra_parameters = {"flatten": flatten,
"unflatten": unflatten,
"return_pytree": kwargs.get("return_pytree", False)}

return {from_pyfunc: from_pyfunc_kwargs,
nutpie.sample: sample_kwargs,
"extra_parameters": extra_parameters}

def __call__(self, seed, **kwargs):
kwargs = self.get_kwargs(**kwargs)
extra_parameters = kwargs["extra_parameters"]
compiled = from_pyfunc(**kwargs[from_pyfunc])
idata = nutpie.sample(compiled_model=compiled,
**kwargs[nutpie.sample])
return _postprocess_idata(idata,
extra_parameters["unflatten"],
self.transform_fn,
extra_parameters["return_pytree"])


def _pytree_to_dict(draws):
if hasattr(draws, "_asdict"):
draws = draws._asdict()
elif not isinstance(draws, dict):
draws = {"var0": draws}

return draws


def _postprocess_idata(idata, unflatten, transform_fn, return_pytree):
"""Convert nutpie inference data back to pytree, transform, and put back."""
unflatten = jax.vmap(jax.vmap(unflatten))
posterior = transform_fn(unflatten(idata.posterior.x.values))

if return_pytree:
return posterior

posterior = _pytree_to_dict(posterior)
warmup_posterior = _pytree_to_dict(
transform_fn(unflatten(idata.warmup_posterior.x.values)))
new_posterior = az.from_dict(posterior=posterior)
new_warmup_posterior = az.from_dict(posterior=warmup_posterior)
del idata.posterior
del idata.warmup_posterior
idata.add_groups(posterior=new_posterior.posterior,
warmup_posterior=new_warmup_posterior.posterior)
return idata
6 changes: 6 additions & 0 deletions bayeux/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@
from bayeux._src.mcmc.numpyro import NUTS as NUTSnumpyro

__all__.extend(["HMCnumpyro", "NUTSnumpyro"])

if importlib.util.find_spec("nutpie") is not None:
from bayeux._src.mcmc.nutpie import _NutpieSampler as NutpieSampler

__all__.extend(["NutpieSampler"])

17 changes: 17 additions & 0 deletions bayeux/tests/mcmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ def test_return_pytree_flowmc():
assert pytree["x"]["y"].shape == (4, 10, 2)


@pytest.mark.skipif(importlib.util.find_spec("nutpie") is None,
reason="Test requires nutpie which is not installed")
def test_return_pytree_nutpie():
model = bx.Model(log_density=lambda pt: -jnp.sum(pt["x"]["y"]**2),
test_point={"x": {"y": jnp.array([1., 1.])}})
seed = jax.random.PRNGKey(0)
pytree = model.mcmc.nutpie(
seed=seed,
return_pytree=True,
chains=4,
draws=10,
tune=10,
)
# 10 draws = (1 local + 1 global) * 5 loops
assert pytree["x"]["y"].shape == (4, 10, 2)


@pytest.mark.parametrize("method", METHODS)
def test_samplers(method):
# flowMC samplers are broken for 0 or 1 dimensions, so just test
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"numpyro",
"jaxopt",
"pymc",
"nutpie",
]

# `version` is automatically set by flit to use `bayeux.__version__`
Expand Down

0 comments on commit e46e7f4

Please sign in to comment.