feat: add sampler ESS and SA (#120)
* add `SA` sampler

* add `ESS` sampler

* test: add tests for `ess`

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: W.-C. Xue <58248583+wcxve@users.noreply.github.com>
3 people authored Sep 23, 2024
1 parent 3e8e49f commit e8a7c69
Showing 2 changed files with 218 additions and 1 deletion.
211 changes: 210 additions & 1 deletion src/elisa/infer/fit.py
@@ -17,7 +17,7 @@
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec
from numpyro.infer import AIES, MCMC, NUTS, init_to_value
from numpyro.infer import AIES, ESS, MCMC, NUTS, SA, init_to_value

from elisa.data.base import FixedData, ObservationData
from elisa.infer.helper import Helper, get_helper
@@ -1064,3 +1064,212 @@ def do_mcmc(rng_key):
                init_params=init,
            )
        return PosteriorResult(sampler, self._helper, self)

    def ess(
        self,
        warmup: int = 2000,
        steps: int = 5000,
        chains: int | None = None,
        init: dict[str, float] | None = None,
        chain_method: str = 'vectorized',
        n_parallel: int | None = None,
        progress: bool = True,
        moves: dict | None = None,
        **ess_kwargs: dict,
    ) -> PosteriorResult:
"""Ensemble Slice Sampling (ESS) of :mod:`numpyro`.
Ensemble Slice Sampling [1]_ is a gradient free method
that finds better slice sampling directions by sharing information
between chains. Suitable for low to moderate dimensional models.
Generally, num_chains should be at least twice the dimensionality of
the model.
.. note::
This kernel must be used with even `num_chains` > 1 and
``chain_method='vectorized'``.
Parameters
----------
warmup : int, optional
Number of warmup steps.
steps : int, optional
Number of steps to run for each chain.
chains : int, optional
Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is
the dimension of model parameters.
init : dict, optional
Initial parameter for sampler to start from.
chain_method : str, optional
Available options are ``'vectorized'`` and ``'parallel'``.
Defaults to ``'vectorized'``.
n_parallel : int, optional
Number of parallel chains to run when `chain_method` is
``"parallel"``. Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to show progress bar during sampling. The default is True.
If `chain_method` is set to ``'parallel'``, this is always False.
moves : dict, optional
Moves for the sampler.
**ess_kwargs : dict
Extra parameters passed to :class:`numpyro.infer.ESS`.
Returns
-------
PosteriorResult
The posterior sampling result.
References
----------
.. [1] zeus: a PYTHON implementation of ensemble slice sampling
for efficient Bayesian parameter inference
(https://academic.oup.com/mnras/article/508/3/3589/6381726),
Minas Karamanis, Florian Beutler, and John A. Peacock.
"""
        if chains is None:
            chains = 4 * len(self._helper.params_names['free'])
        else:
            chains = int(chains)

        # TODO: option to let the sampler start from the MLE
        if init is None:
            init = self._helper.free_default['constr_dic']
        else:
            init = self._helper.free_default['constr_dic'] | dict(init)
        init = self._helper.constr_dic_to_unconstr_arr(init)
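        # spread the chains' starting points by jittering the initial
        # (unconstrained) parameter vector by up to ±10%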
        rng = np.random.default_rng(self._helper.seed['mcmc'])
        jitter = 0.1 * np.abs(init)
        low = init - jitter
        high = init + jitter
        init = rng.uniform(low, high, size=(chains, len(init)))
        init = dict(zip(self._helper.params_names['free'], init.T))

        ess_kwargs['model'] = self._helper.numpyro_model
        if moves is None:
            ess_kwargs['moves'] = {ESS.DifferentialMove(): 1.0}
        else:
            ess_kwargs['moves'] = moves

        if chain_method == 'parallel':
            ess_kernel = ESS(**ess_kwargs)

            def do_mcmc(rng_key):
                mcmc = MCMC(
                    ess_kernel,
                    num_warmup=warmup,
                    num_samples=steps,
                    num_chains=chains,
                    chain_method='vectorized',
                    progress_bar=False,
                )
                mcmc.run(
                    rng_key,
                    init_params=init,
                )
                return mcmc.get_samples(group_by_chain=True)

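            # split the RNG key and run one vectorized ensemble
            # per device via pmap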
            rng_keys = jax.random.split(
                jax.random.PRNGKey(self._helper.seed['mcmc']),
                get_parallel_number(n_parallel),
            )
            traces = jax.pmap(do_mcmc)(rng_keys)
            trace = {k: np.concatenate(v) for k, v in traces.items()}

            sampler = MCMC(
                ess_kernel,
                num_warmup=warmup,
                num_samples=steps,
            )
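            # attach the pmap-collected samples to a placeholder MCMC
            # object so downstream code can read them via get_samples()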
            sampler._states = {sampler._sample_field: trace}

        else:
            sampler = MCMC(
                ESS(**ess_kwargs),
                num_warmup=warmup,
                num_samples=steps,
                num_chains=chains,
                chain_method=chain_method,
                progress_bar=progress,
            )

            sampler.run(
                rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
                init_params=init,
            )
        return PosteriorResult(sampler, self._helper, self)

    def sa(
        self,
        warmup: int = 20000,
        steps: int = 300000,
        chains: int | None = None,
        init: dict[str, float] | None = None,
        chain_method: str = 'parallel',
        progress: bool = True,
        **sa_kwargs: dict,
    ) -> PosteriorResult:
"""Run the Sample Adaptive MCMC of :mod:`numpyro`.
Sample Adaptive MCMC, a gradient-free sampler. [1]_
Parameters
----------
warmup : int, optional
Number of warmup steps.
steps : int, optional
Number of steps to run for each chain.
chains : int, optional
Number of MCMC chains to run. If there are not enough devices
available, chains will run in sequence. Defaults to the number of
``jax.local_device_count()``.
init : dict, optional
Initial parameter for sampler to start from.
chain_method : str, optional
The chain method passed to :class:`numpyro.infer.MCMC`.
progress : bool, optional
Whether to show progress bar during sampling. The default is True.
**sa_kwargs : dict
Extra parameters passed to :class:`numpyro.infer.SA`.
Returns
-------
PosteriorResult
The posterior sampling result.
References
----------
.. [1] Sample Adaptive MCMC by Michael Zhu
<https://papers.nips.cc/paper/9107-sample-adaptive-mcmc>`__
"""
        if chains is None:
            chains = 4 * len(self._helper.params_names['free'])
        else:
            chains = int(chains)

        # TODO: option to let the sampler start from the MLE
        if init is None:
            init = self._helper.free_default['constr_dic']
        else:
            init = self._helper.free_default['constr_dic'] | dict(init)

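        # default to a dense mass matrix; adapt_state_size=None defers
        # to numpyro's own default for the adaptive state size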
        default_sa_kwargs = {
            'dense_mass': True,
            'adapt_state_size': None,
        }
        sa_kwargs = default_sa_kwargs | sa_kwargs
        sa_kwargs['model'] = self._helper.numpyro_model
        sa_kwargs['init_strategy'] = init_to_value(values=init)

        sampler = MCMC(
            SA(**sa_kwargs),
            num_warmup=warmup,
            num_samples=steps,
            num_chains=chains,
            chain_method=chain_method,
            progress_bar=progress,
        )

        sampler.run(
            rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
        )
        return PosteriorResult(sampler, self._helper, self)
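
For orientation, here is a minimal usage sketch of the two new samplers, modeled on the test below. The import path, `data`, and `model` are illustrative assumptions, not part of this commit:

import numpy as np

from elisa.infer.fit import BayesFit  # assumed import path

fit = BayesFit(data, model)  # placeholder data and model objects

# ESS: gradient-free ensemble slice sampling; the number of chains
# defaults to 4 * D and must be even
result = fit.ess(warmup=2000, steps=5000)

# SA: gradient-free sample adaptive MCMC; note the much larger
# default step counts (warmup=20000, steps=300000)
result = fit.sa()

# convergence check, as in the updated test below
assert all(r < 1.01 for r in result.rhat.values() if not np.isnan(r))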
8 changes: 8 additions & 0 deletions tests/test_fit.py
@@ -103,6 +103,11 @@ def test_trivial_bayes_fit():
            'chain_method': 'parallel',
            'n_parallel': 4,
        },
        'ess': {
            'chain_method': 'parallel',
            'n_parallel': 4,
        },
        # 'sa': {},
        'ultranest': {},
        'nautilus': {},
    }
@@ -111,6 +116,9 @@
        # Get the Bayesian fit result, i.e. the posterior
        result = getattr(BayesFit(data, model), method)(**kwargs)

        # check convergence
        assert all(i < 1.01 for i in result.rhat.values() if not np.isnan(i))

        # check the true parameter values are within the 90% CI
        ci = result.ci(cl=0.9).intervals
        assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1]