Skip to content

Commit

Permalink
Add TFP NUTS and HMC samplers.
Browse files Browse the repository at this point in the history
Also clean up keyword argument handling.

PiperOrigin-RevId: 607003485
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Feb 14, 2024
1 parent 4ee0eb9 commit 0789573
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 37 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.1.8] - 2024-02-14

### Add HMC and NUTS from TFP
### Small change to blackjax default step size

## [0.1.7] - 2024-02-13

### Add SNAPER HMC from TFP
Expand Down
2 changes: 1 addition & 1 deletion bayeux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.7'
__version__ = '0.1.8'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
Expand Down
22 changes: 7 additions & 15 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@


def get_extra_kwargs(kwargs):
return {
defaults = {
"chain_method": "vectorized",
"num_chains": 8,
"num_draws": 500,
"num_adapt_draws": 500,
"return_pytree": False,
} | kwargs
"return_pytree": False}
shared.update_with_kwargs(defaults, **kwargs)
return defaults


class _BlackjaxSampler(shared.Base):
Expand Down Expand Up @@ -299,7 +300,7 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs):
run_kwargs["optim"] = optax.adam(learning_rate=0.01)
run_required.remove("optim")
if "step_size" in run_required:
run_kwargs["step_size"] = 0.001
run_kwargs["step_size"] = 0.5
run_required.remove("step_size")
run_kwargs["num_steps"] = kwargs.get("num_adapt_draws",
run_kwargs["num_steps"])
Expand All @@ -315,12 +316,8 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs):
"step_size": 0.5,
"num_integration_steps": 16,
} | kwargs
algorithm_kwargs.update(
{
k: kwargs_with_defaults[k]
for k in algorithm_required
if k in kwargs_with_defaults
})
shared.update_with_kwargs(
algorithm_kwargs, algorithm_required, **kwargs_with_defaults)
algorithm_required.remove("logdensity_fn")
algorithm_required.discard("inverse_mass_matrix")
algorithm_required.discard("alpha")
Expand All @@ -332,11 +329,6 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs):
raise ValueError(f"Unexpected required arguments: "
f"{','.join(algorithm_required)}. Probably file a bug, but"
" you can try to manually supply them as keywords.")
algorithm_kwargs.update(
{
k: kwargs_with_defaults[k]
for k in algorithm_kwargs
if k in kwargs_with_defaults})
return algorithm_kwargs


Expand Down
14 changes: 4 additions & 10 deletions bayeux/_src/mcmc/flowmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs):

nf_model_kwargs, nf_model_required = shared.get_default_signature(
nf_model)
nf_model_kwargs.update(
{k: defaults[k] for k in nf_model_required if k in defaults})
shared.update_with_kwargs(nf_model_kwargs, nf_model_required, **defaults)
nf_model_required.remove("key")
nf_model_required.remove("kwargs")
nf_model_required = nf_model_required - nf_model_kwargs.keys()
Expand All @@ -74,9 +73,6 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs):
f"{','.join(nf_model_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
nf_model_kwargs.update(
{k: defaults[k] for k in nf_model_kwargs if k in defaults})

return nf_model_kwargs


Expand All @@ -94,9 +90,8 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):

sampler_kwargs, sampler_required = shared.get_default_signature(
local_sampler)
shared.update_with_kwargs(sampler_kwargs, sampler_required, **defaults)
sampler_kwargs.setdefault("jit", True)
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
sampler_required = sampler_required - sampler_kwargs.keys()
if "params" in sampler_required:
sampler_kwargs["params"] = defaults
Expand Down Expand Up @@ -143,8 +138,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
"n_dim": n_features,
"data": {}} | kwargs
sampler_kwargs, sampler_required = shared.get_default_signature(sampler)
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
shared.update_with_kwargs(sampler_kwargs, sampler_required, **defaults)
sampler_required = (sampler_required -
{"nf_model", "local_sampler", "rng_key_set", "kwargs"})
sampler_required = sampler_required - sampler_kwargs.keys()
Expand All @@ -155,7 +149,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
f"{','.join(sampler_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
return defaults | sampler_kwargs
return sampler_kwargs


class _FlowMCSampler(shared.Base):
Expand Down
12 changes: 2 additions & 10 deletions bayeux/_src/mcmc/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class NUTS(_NumpyroSampler):
def get_sampler_kwargs(algorithm, kwargs):
"""Construct default args and include user arguments for samplers."""
sampler_kwargs, sampler_required = shared.get_default_signature(algorithm)
sampler_kwargs.update({k: kwargs[k] for k in sampler_required if k in kwargs})
shared.update_with_kwargs(sampler_kwargs, sampler_required, **kwargs)
sampler_kwargs.pop("potential_fn")

sampler_required = sampler_required - sampler_kwargs.keys()
Expand All @@ -84,7 +84,6 @@ def get_sampler_kwargs(algorithm, kwargs):
raise ValueError(f"Unexpected required arguments: "
f"{','.join(sampler_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords.")
sampler_kwargs.update({k: kwargs[k] for k in sampler_kwargs if k in kwargs})
return sampler_kwargs


Expand All @@ -97,20 +96,13 @@ def get_mcmc_kwargs(kwargs):
"num_chains": 8,
"chain_method": "vectorized",
} | kwargs
mcmc_kwargs.update(
{k: kwargs_with_defaults[k] for k in mcmc_required
if k in kwargs_with_defaults})

shared.update_with_kwargs(mcmc_kwargs, mcmc_required, **kwargs_with_defaults)
mcmc_required = mcmc_required - mcmc_kwargs.keys()
mcmc_required.remove("sampler")

if mcmc_required:
raise ValueError(f"Unexpected required arguments: "
f"{','.join(mcmc_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords.")
mcmc_kwargs.update(
{k: kwargs_with_defaults[k] for k in mcmc_kwargs
if k in kwargs_with_defaults})
return mcmc_kwargs


Expand Down
129 changes: 129 additions & 0 deletions bayeux/_src/mcmc/tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,22 @@
import arviz as az
from bayeux._src import shared
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp


_ALGORITHMS = {
"hmc": tfp.mcmc.HamiltonianMonteCarlo,
"nuts": tfp.mcmc.NoUTurnSampler,
}

_TRACE_FNS = {
"hmc": tfp.experimental.mcmc.windowed_sampling.default_hmc_trace_fn,
"nuts": tfp.experimental.mcmc.windowed_sampling.default_nuts_trace_fn,
}


class SnaperHMC(shared.Base):
"""Implements SNAPER HMC [1] with step size adaptation.
Expand Down Expand Up @@ -104,6 +116,123 @@ def tlp(*args, **kwargs):
return az.from_dict(posterior=draws, sample_stats=_tfp_stats_to_dict(trace))


class _TFPBase(shared.Base):
"""Base class for TFP windowed samplers."""
name: str = ""
algorithm: str = ""

def get_kwargs(self, **kwargs):
if self.algorithm == "nuts":
target_accept_prob = 0.8
else:
target_accept_prob = 0.6

kwargs = {
"target_accept_prob": target_accept_prob,
"num_adaptation_steps": 500,
"step_size": 0.5,
"num_leapfrog_steps": 8,
} | kwargs
extra_parameters = {
"num_draws": 1_000,
"num_chains": 8,
"num_adaptation_steps": 500,
"return_pytree": False,
}
shared.update_with_kwargs(extra_parameters, **kwargs)

dual_averaging_kwargs, da_reqd = shared.get_default_signature(
tfp.mcmc.DualAveragingStepSizeAdaptation)
shared.update_with_kwargs(dual_averaging_kwargs, da_reqd, **kwargs)
da_reqd = da_reqd - dual_averaging_kwargs.keys()
da_reqd.remove("inner_kernel")
if da_reqd:
raise ValueError(
"Unexpected required arguments: "
f"{','.join(da_reqd)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)

proposal_kwargs, proposal_reqd = shared.get_default_signature(
_ALGORITHMS[self.algorithm])
shared.update_with_kwargs(proposal_kwargs, proposal_reqd, **kwargs)
proposal_reqd = proposal_reqd - proposal_kwargs.keys()
proposal_reqd.remove("target_log_prob_fn")
if proposal_reqd:
raise ValueError(
"Unexpected required arguments: "
f"{','.join(proposal_reqd)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
return {
"extra_parameters": extra_parameters,
"dual_averaging_kwargs": dual_averaging_kwargs,
"proposal_kernel_kwargs": proposal_kwargs}

def __call__(self, seed, **kwargs):
kwargs = self.get_kwargs(**kwargs)
init_key, sample_key = jax.random.split(seed)

extra_parameters = kwargs["extra_parameters"]
dual_averaging_kwargs = kwargs["dual_averaging_kwargs"]
proposal_kernel_kwargs = kwargs["proposal_kernel_kwargs"]
initial_state = self.get_initial_state(
init_key, num_chains=extra_parameters["num_chains"])

vmapped_constrained_log_density = jax.vmap(self.constrained_log_density())
initial_transformed_position, treedef = jax.tree_util.tree_flatten(
self.inverse_transform_fn(initial_state))

def target_log_prob_fn(*args):
return vmapped_constrained_log_density(
jax.tree_util.tree_unflatten(treedef, args))
proposal_kernel_kwargs["target_log_prob_fn"] = target_log_prob_fn

initial_running_variance = [
tfp.experimental.stats.sample_stats.RunningVariance.from_stats(
num_samples=jnp.array(1, part.dtype),
mean=jnp.zeros_like(part),
variance=jnp.ones_like(part))
for part in initial_transformed_position]

draws, trace = tfp.experimental.mcmc.windowed_sampling._do_sampling(
kind=self.algorithm,
proposal_kernel_kwargs=proposal_kernel_kwargs,
dual_averaging_kwargs=dual_averaging_kwargs,
num_draws=extra_parameters["num_draws"],
num_burnin_steps=extra_parameters["num_adaptation_steps"],
initial_position=initial_transformed_position,
initial_running_variance=initial_running_variance,
bijector=None,
trace_fn=_TRACE_FNS[self.algorithm],
return_final_kernel_results=False,
chain_axis_names=None,
shard_axis_names=None,
seed=sample_key)

draws = self.transform_fn(jax.tree_util.tree_unflatten(treedef, draws))
if extra_parameters["return_pytree"]:
return draws

if hasattr(draws, "_asdict"):
draws = draws._asdict()
elif not isinstance(draws, dict):
draws = {"var0": draws}

draws = {x: np.swapaxes(v, 0, 1) for x, v in draws.items()}
return az.from_dict(posterior=draws, sample_stats=_tfp_stats_to_dict(trace))


class NUTS(_TFPBase):
name = "tfp_nuts"
algorithm = "nuts"


class HMC(_TFPBase):
name = "tfp_hmc"
algorithm = "hmc"


def _tfp_stats_to_dict(stats):
new_stats = {}
for k, v in stats.items():
Expand Down
8 changes: 8 additions & 0 deletions bayeux/_src/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def map_fn(chain_method, fn):
raise ValueError(f"Chain method {chain_method} not supported.")


def update_with_kwargs(defaults, reqd=None, **kwargs):
if reqd is None:
reqd = set()
defaults.update(
(k, kwargs[k]) for k in (defaults.keys() | reqd) & kwargs.keys())
return defaults


def _default_init(
*,
initial_state,
Expand Down
4 changes: 3 additions & 1 deletion bayeux/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import importlib

# TFP-on-JAX always installed
from bayeux._src.mcmc.tfp import HMC as HMC_TFP
from bayeux._src.mcmc.tfp import NUTS as NUTS_TFP
from bayeux._src.mcmc.tfp import SnaperHMC as SNAPER_HMC_TFP
__all__ = ["SNAPER_HMC_TFP"]
__all__ = ["HMC_TFP", "NUTS_TFP", "SNAPER_HMC_TFP"]

if importlib.util.find_spec("blackjax") is not None:
from bayeux._src.mcmc.blackjax import CheesHMC as CheesHMCblackjax
Expand Down
14 changes: 14 additions & 0 deletions bayeux/tests/mcmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ def test_return_pytree_tfp():
assert pytree["x"]["y"].shape == (10, 4)


def test_return_pytree_tfp_nuts():
model = bx.Model(log_density=lambda pt: -pt["x"]["y"]**2,
test_point={"x": {"y": jnp.array(1.)}})
seed = jax.random.PRNGKey(0)
pytree = model.mcmc.tfp_nuts(
seed=seed,
return_pytree=True,
num_chains=4,
num_draws=10,
num_adaptation_steps=10,
)
assert pytree["x"]["y"].shape == (10, 4)


@pytest.mark.skipif(importlib.util.find_spec("flowMC") is None,
reason="Test requires flowMC which is not installed")
def test_return_pytree_flowmc():
Expand Down

0 comments on commit 0789573

Please sign in to comment.