From 078957364c761d1f6ca7a6199b819e9442b78dbf Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Wed, 14 Feb 2024 09:01:50 -0800 Subject: [PATCH] Add TFP NUTS and HMC samplers. Also clean up keyword argument handling. PiperOrigin-RevId: 607003485 --- CHANGELOG.md | 5 ++ bayeux/__init__.py | 2 +- bayeux/_src/mcmc/blackjax.py | 22 ++---- bayeux/_src/mcmc/flowmc.py | 14 ++-- bayeux/_src/mcmc/numpyro.py | 12 +--- bayeux/_src/mcmc/tfp.py | 129 +++++++++++++++++++++++++++++++++++ bayeux/_src/shared.py | 8 +++ bayeux/mcmc/__init__.py | 4 +- bayeux/tests/mcmc_test.py | 14 ++++ 9 files changed, 173 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36f8fed..bb786fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.1.8] - 2024-02-14 + +### Add HMC and NUTS from TFP +### Small change to blackjax default step size + ## [0.1.7] - 2024-02-13 ### Add SNAPER HMC from TFP diff --git a/bayeux/__init__.py b/bayeux/__init__.py index 026e7db..64c2d69 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.7' +__version__ = '0.1.8' # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index d02f34d..5221e7b 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -39,13 +39,14 @@ def get_extra_kwargs(kwargs): - return { + defaults = { "chain_method": "vectorized", "num_chains": 8, "num_draws": 500, "num_adapt_draws": 500, - "return_pytree": False, - } | kwargs + "return_pytree": False} + shared.update_with_kwargs(defaults, **kwargs) + return defaults class _BlackjaxSampler(shared.Base): @@ -299,7 +300,7 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs): run_kwargs["optim"] = optax.adam(learning_rate=0.01) run_required.remove("optim") if "step_size" in run_required: - run_kwargs["step_size"] = 0.001 + run_kwargs["step_size"] = 0.5 run_required.remove("step_size") run_kwargs["num_steps"] = kwargs.get("num_adapt_draws", run_kwargs["num_steps"]) @@ -315,12 +316,8 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs): "step_size": 0.5, "num_integration_steps": 16, } | kwargs - algorithm_kwargs.update( - { - k: kwargs_with_defaults[k] - for k in algorithm_required - if k in kwargs_with_defaults - }) + shared.update_with_kwargs( + algorithm_kwargs, algorithm_required, **kwargs_with_defaults) algorithm_required.remove("logdensity_fn") algorithm_required.discard("inverse_mass_matrix") algorithm_required.discard("alpha") @@ -332,11 +329,6 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs): raise ValueError(f"Unexpected required arguments: " f"{','.join(algorithm_required)}. Probably file a bug, but" " you can try to manually supply them as keywords.") - algorithm_kwargs.update( - { - k: kwargs_with_defaults[k] - for k in algorithm_kwargs - if k in kwargs_with_defaults}) return algorithm_kwargs diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py index b5376d5..8381b25 100644 --- a/bayeux/_src/mcmc/flowmc.py +++ b/bayeux/_src/mcmc/flowmc.py @@ -62,8 +62,7 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs): nf_model_kwargs, nf_model_required = shared.get_default_signature( nf_model) - nf_model_kwargs.update( - {k: defaults[k] for k in nf_model_required if k in defaults}) + shared.update_with_kwargs(nf_model_kwargs, nf_model_required, **defaults) nf_model_required.remove("key") nf_model_required.remove("kwargs") nf_model_required = nf_model_required - nf_model_kwargs.keys() @@ -74,9 +73,6 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs): f"{','.join(nf_model_required)}. Probably file a bug, but " "you can try to manually supply them as keywords." ) - nf_model_kwargs.update( - {k: defaults[k] for k in nf_model_kwargs if k in defaults}) - return nf_model_kwargs @@ -94,9 +90,8 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): sampler_kwargs, sampler_required = shared.get_default_signature( local_sampler) + shared.update_with_kwargs(sampler_kwargs, sampler_required, **defaults) sampler_kwargs.setdefault("jit", True) - sampler_kwargs.update( - {k: defaults[k] for k in sampler_required if k in defaults}) sampler_required = sampler_required - sampler_kwargs.keys() if "params" in sampler_required: sampler_kwargs["params"] = defaults @@ -143,8 +138,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs): "n_dim": n_features, "data": {}} | kwargs sampler_kwargs, sampler_required = shared.get_default_signature(sampler) - sampler_kwargs.update( - {k: defaults[k] for k in sampler_required if k in defaults}) + shared.update_with_kwargs(sampler_kwargs, sampler_required, **defaults) sampler_required = (sampler_required - {"nf_model", "local_sampler", "rng_key_set", "kwargs"}) sampler_required = sampler_required - sampler_kwargs.keys() @@ -155,7 +149,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs): f"{','.join(sampler_required)}. Probably file a bug, but " "you can try to manually supply them as keywords." ) - return defaults | sampler_kwargs + return sampler_kwargs class _FlowMCSampler(shared.Base): diff --git a/bayeux/_src/mcmc/numpyro.py b/bayeux/_src/mcmc/numpyro.py index 0b174d0..11cad3a 100644 --- a/bayeux/_src/mcmc/numpyro.py +++ b/bayeux/_src/mcmc/numpyro.py @@ -75,7 +75,7 @@ class NUTS(_NumpyroSampler): def get_sampler_kwargs(algorithm, kwargs): """Construct default args and include user arguments for samplers.""" sampler_kwargs, sampler_required = shared.get_default_signature(algorithm) - sampler_kwargs.update({k: kwargs[k] for k in sampler_required if k in kwargs}) + shared.update_with_kwargs(sampler_kwargs, sampler_required, **kwargs) sampler_kwargs.pop("potential_fn") sampler_required = sampler_required - sampler_kwargs.keys() @@ -84,7 +84,6 @@ def get_sampler_kwargs(algorithm, kwargs): raise ValueError(f"Unexpected required arguments: " f"{','.join(sampler_required)}. Probably file a bug, but " "you can try to manually supply them as keywords.") - sampler_kwargs.update({k: kwargs[k] for k in sampler_kwargs if k in kwargs}) return sampler_kwargs @@ -97,20 +96,13 @@ def get_mcmc_kwargs(kwargs): "num_chains": 8, "chain_method": "vectorized", } | kwargs - mcmc_kwargs.update( - {k: kwargs_with_defaults[k] for k in mcmc_required - if k in kwargs_with_defaults}) - + shared.update_with_kwargs(mcmc_kwargs, mcmc_required, **kwargs_with_defaults) mcmc_required = mcmc_required - mcmc_kwargs.keys() mcmc_required.remove("sampler") - if mcmc_required: raise ValueError(f"Unexpected required arguments: " f"{','.join(mcmc_required)}. Probably file a bug, but " "you can try to manually supply them as keywords.") - mcmc_kwargs.update( - {k: kwargs_with_defaults[k] for k in mcmc_kwargs - if k in kwargs_with_defaults}) return mcmc_kwargs diff --git a/bayeux/_src/mcmc/tfp.py b/bayeux/_src/mcmc/tfp.py index f682a91..36507da 100644 --- a/bayeux/_src/mcmc/tfp.py +++ b/bayeux/_src/mcmc/tfp.py @@ -31,10 +31,22 @@ import arviz as az from bayeux._src import shared import jax +import jax.numpy as jnp import numpy as np import tensorflow_probability.substrates.jax as tfp +_ALGORITHMS = { + "hmc": tfp.mcmc.HamiltonianMonteCarlo, + "nuts": tfp.mcmc.NoUTurnSampler, +} + +_TRACE_FNS = { + "hmc": tfp.experimental.mcmc.windowed_sampling.default_hmc_trace_fn, + "nuts": tfp.experimental.mcmc.windowed_sampling.default_nuts_trace_fn, +} + + class SnaperHMC(shared.Base): """Implements SNAPER HMC [1] with step size adaptation. @@ -104,6 +116,123 @@ def tlp(*args, **kwargs): return az.from_dict(posterior=draws, sample_stats=_tfp_stats_to_dict(trace)) +class _TFPBase(shared.Base): + """Base class for TFP windowed samplers.""" + name: str = "" + algorithm: str = "" + + def get_kwargs(self, **kwargs): + if self.algorithm == "nuts": + target_accept_prob = 0.8 + else: + target_accept_prob = 0.6 + + kwargs = { + "target_accept_prob": target_accept_prob, + "num_adaptation_steps": 500, + "step_size": 0.5, + "num_leapfrog_steps": 8, + } | kwargs + extra_parameters = { + "num_draws": 1_000, + "num_chains": 8, + "num_adaptation_steps": 500, + "return_pytree": False, + } + shared.update_with_kwargs(extra_parameters, **kwargs) + + dual_averaging_kwargs, da_reqd = shared.get_default_signature( + tfp.mcmc.DualAveragingStepSizeAdaptation) + shared.update_with_kwargs(dual_averaging_kwargs, da_reqd, **kwargs) + da_reqd = da_reqd - dual_averaging_kwargs.keys() + da_reqd.remove("inner_kernel") + if da_reqd: + raise ValueError( + "Unexpected required arguments: " + f"{','.join(da_reqd)}. Probably file a bug, but " + "you can try to manually supply them as keywords." + ) + + proposal_kwargs, proposal_reqd = shared.get_default_signature( + _ALGORITHMS[self.algorithm]) + shared.update_with_kwargs(proposal_kwargs, proposal_reqd, **kwargs) + proposal_reqd = proposal_reqd - proposal_kwargs.keys() + proposal_reqd.remove("target_log_prob_fn") + if proposal_reqd: + raise ValueError( + "Unexpected required arguments: " + f"{','.join(proposal_reqd)}. Probably file a bug, but " + "you can try to manually supply them as keywords." + ) + return { + "extra_parameters": extra_parameters, + "dual_averaging_kwargs": dual_averaging_kwargs, + "proposal_kernel_kwargs": proposal_kwargs} + + def __call__(self, seed, **kwargs): + kwargs = self.get_kwargs(**kwargs) + init_key, sample_key = jax.random.split(seed) + + extra_parameters = kwargs["extra_parameters"] + dual_averaging_kwargs = kwargs["dual_averaging_kwargs"] + proposal_kernel_kwargs = kwargs["proposal_kernel_kwargs"] + initial_state = self.get_initial_state( + init_key, num_chains=extra_parameters["num_chains"]) + + vmapped_constrained_log_density = jax.vmap(self.constrained_log_density()) + initial_transformed_position, treedef = jax.tree_util.tree_flatten( + self.inverse_transform_fn(initial_state)) + + def target_log_prob_fn(*args): + return vmapped_constrained_log_density( + jax.tree_util.tree_unflatten(treedef, args)) + proposal_kernel_kwargs["target_log_prob_fn"] = target_log_prob_fn + + initial_running_variance = [ + tfp.experimental.stats.sample_stats.RunningVariance.from_stats( + num_samples=jnp.array(1, part.dtype), + mean=jnp.zeros_like(part), + variance=jnp.ones_like(part)) + for part in initial_transformed_position] + + draws, trace = tfp.experimental.mcmc.windowed_sampling._do_sampling( + kind=self.algorithm, + proposal_kernel_kwargs=proposal_kernel_kwargs, + dual_averaging_kwargs=dual_averaging_kwargs, + num_draws=extra_parameters["num_draws"], + num_burnin_steps=extra_parameters["num_adaptation_steps"], + initial_position=initial_transformed_position, + initial_running_variance=initial_running_variance, + bijector=None, + trace_fn=_TRACE_FNS[self.algorithm], + return_final_kernel_results=False, + chain_axis_names=None, + shard_axis_names=None, + seed=sample_key) + + draws = self.transform_fn(jax.tree_util.tree_unflatten(treedef, draws)) + if extra_parameters["return_pytree"]: + return draws + + if hasattr(draws, "_asdict"): + draws = draws._asdict() + elif not isinstance(draws, dict): + draws = {"var0": draws} + + draws = {x: np.swapaxes(v, 0, 1) for x, v in draws.items()} + return az.from_dict(posterior=draws, sample_stats=_tfp_stats_to_dict(trace)) + + +class NUTS(_TFPBase): + name = "tfp_nuts" + algorithm = "nuts" + + +class HMC(_TFPBase): + name = "tfp_hmc" + algorithm = "hmc" + + def _tfp_stats_to_dict(stats): new_stats = {} for k, v in stats.items(): diff --git a/bayeux/_src/shared.py b/bayeux/_src/shared.py index 3c8673d..d8cd9fe 100644 --- a/bayeux/_src/shared.py +++ b/bayeux/_src/shared.py @@ -37,6 +37,14 @@ def map_fn(chain_method, fn): raise ValueError(f"Chain method {chain_method} not supported.") +def update_with_kwargs(defaults, reqd=None, **kwargs): + if reqd is None: + reqd = set() + defaults.update( + (k, kwargs[k]) for k in (defaults.keys() | reqd) & kwargs.keys()) + return defaults + + def _default_init( *, initial_state, diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py index a6978f6..9c8ed81 100644 --- a/bayeux/mcmc/__init__.py +++ b/bayeux/mcmc/__init__.py @@ -19,8 +19,10 @@ import importlib # TFP-on-JAX always installed +from bayeux._src.mcmc.tfp import HMC as HMC_TFP +from bayeux._src.mcmc.tfp import NUTS as NUTS_TFP from bayeux._src.mcmc.tfp import SnaperHMC as SNAPER_HMC_TFP -__all__ = ["SNAPER_HMC_TFP"] +__all__ = ["HMC_TFP", "NUTS_TFP", "SNAPER_HMC_TFP"] if importlib.util.find_spec("blackjax") is not None: from bayeux._src.mcmc.blackjax import CheesHMC as CheesHMCblackjax diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index 48e46d5..7fbb4fb 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -79,6 +79,20 @@ def test_return_pytree_tfp(): assert pytree["x"]["y"].shape == (10, 4) +def test_return_pytree_tfp_nuts(): + model = bx.Model(log_density=lambda pt: -pt["x"]["y"]**2, + test_point={"x": {"y": jnp.array(1.)}}) + seed = jax.random.PRNGKey(0) + pytree = model.mcmc.tfp_nuts( + seed=seed, + return_pytree=True, + num_chains=4, + num_draws=10, + num_adaptation_steps=10, + ) + assert pytree["x"]["y"].shape == (10, 4) + + @pytest.mark.skipif(importlib.util.find_spec("flowMC") is None, reason="Test requires flowMC which is not installed") def test_return_pytree_flowmc():