From 65954f3755d9dc692f6c45611d03a86396da8a98 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:34:27 +0800 Subject: [PATCH 1/6] aies parallel setting --- src/elisa/infer/fit.py | 53 ++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index fb2e1a6..7b3740d 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -956,6 +956,8 @@ def aies( Initial parameter for sampler to start from. chain_method : str, optional The chain method passed to :class:`numpyro.inf.MCMC`. + Only 'vectorized' or 'parallel' could be set. + Make sure `n` host devices are same with chains, If set 'parallel'. progress : bool, optional Whether to show progress bar during sampling. The default is True. moves : dict, optional @@ -1002,17 +1004,44 @@ def aies( else: aies_kwargs['moves'] = moves - sampler = MCMC( - AIES(**aies_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method=chain_method, - progress_bar=progress, - ) + if chain_method == 'parallel': + aies_kernel = AIES(**aies_kwargs) + def do_mcmc(rng_key): + mcmc = MCMC( + aies_kernel, + num_warmup=0, + num_samples=warmup, + num_chains=chains, + chain_method='vectorized', + progress_bar=False, + ) + mcmc.run(rng_key, init_params=init,) + return mcmc.get_samples(), mcmc.get_extra_fields() + # + rng_keys = jax.random.split(jax.random.PRNGKey(self._helper.seed['mcmc']), + chains) + traces = jax.pmap(do_mcmc)(rng_keys) + + sampler = MCMC( + aies_kernel, + num_warmup=warmup, + num_samples=steps, + ) + sampler._states = {sampler._sample_field: traces[0]} + return PosteriorResult(sampler, self._helper, self) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) + else: + sampler = MCMC( + AIES(**aies_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method=chain_method, + progress_bar=progress, + ) + + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + init_params=init, + ) return PosteriorResult(sampler, self._helper, self) From ad0622cc07178eb4c3373187c14d3c0e10af7ab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 08:36:42 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 7b3740d..02d6cb8 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -1006,6 +1006,7 @@ def aies( if chain_method == 'parallel': aies_kernel = AIES(**aies_kwargs) + def do_mcmc(rng_key): mcmc = MCMC( aies_kernel, @@ -1015,11 +1016,16 @@ def do_mcmc(rng_key): chain_method='vectorized', progress_bar=False, ) - mcmc.run(rng_key, init_params=init,) + mcmc.run( + rng_key, + init_params=init, + ) return mcmc.get_samples(), mcmc.get_extra_fields() - # - rng_keys = jax.random.split(jax.random.PRNGKey(self._helper.seed['mcmc']), - chains) + + # + rng_keys = jax.random.split( + jax.random.PRNGKey(self._helper.seed['mcmc']), chains + ) traces = jax.pmap(do_mcmc)(rng_keys) sampler = MCMC( From 8aaa209224f2fbf4b6e450901c904753c43aecf9 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:03:53 +0800 Subject: [PATCH 3/6] fix aies parallel setting --- src/elisa/infer/fit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 02d6cb8..17a5baf 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -1010,8 +1010,8 @@ def aies( def do_mcmc(rng_key): mcmc = MCMC( aies_kernel, - num_warmup=0, - num_samples=warmup, + num_warmup=warmup, + num_samples=steps, num_chains=chains, chain_method='vectorized', progress_bar=False, @@ -1020,7 +1020,7 @@ def do_mcmc(rng_key): rng_key, init_params=init, ) - return mcmc.get_samples(), mcmc.get_extra_fields() + return {**mcmc.get_samples()} # rng_keys = jax.random.split( @@ -1033,7 +1033,7 @@ def do_mcmc(rng_key): num_warmup=warmup, num_samples=steps, ) - sampler._states = {sampler._sample_field: traces[0]} + sampler._states = {sampler._sample_field: traces} return PosteriorResult(sampler, self._helper, self) else: From 9fed1177364a773e3d4521f216d83544442c3870 Mon Sep 17 00:00:00 2001 From: xuewc Date: Thu, 29 Aug 2024 18:51:05 +0800 Subject: [PATCH 4/6] add helper function --- src/elisa/util/misc.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/elisa/util/misc.py b/src/elisa/util/misc.py index 4318356..56c225d 100644 --- a/src/elisa/util/misc.py +++ b/src/elisa/util/misc.py @@ -4,6 +4,7 @@ import math import re +import warnings from functools import reduce from typing import TYPE_CHECKING @@ -243,6 +244,40 @@ def fdjvp(primals, tangents): return fn +def get_parallel_number(n: int | None) -> int: + """Check and return the available parallel number in JAX. + + Parameters + ---------- + n : int, optional + The desired number of parallel processes in JAX. + + Returns + ------- + int + The available number of parallel processes. + """ + n_max = jax.local_device_count() + + if n is None: + return n_max + else: + n = int(n) + if n <= 0: + raise ValueError('`n` must be positive') + + if n > n_max: + warnings.warn( + f'number of parallel processes ({n}) is more than the number of ' + f'available devices ({jax.local_device_count()}), reset to ' + f'{jax.local_device_count()}', + Warning, + ) + n = jax.local_device_count() + + return n + + def get_unit_latex(unit: str, throw: bool = True) -> str: """Get latex string of a unit. From 7138d29470ecf86dd7bd03410b7826298ca33820 Mon Sep 17 00:00:00 2001 From: xuewc Date: Thu, 29 Aug 2024 18:51:35 +0800 Subject: [PATCH 5/6] fix: parallel chain method --- src/elisa/infer/fit.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 17a5baf..fbedd04 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -25,7 +25,12 @@ from elisa.infer.nested_sampling import NestedSampler, reparam_loglike from elisa.infer.results import MLEResult, PosteriorResult from elisa.models.model import Model, get_model_info -from elisa.util.misc import add_suffix, build_namespace, make_pretty_table +from elisa.util.misc import ( + add_suffix, + build_namespace, + get_parallel_number, + make_pretty_table, +) if TYPE_CHECKING: from typing import Any, Callable, Literal @@ -927,6 +932,7 @@ def aies( chains: int | None = None, init: dict[str, float] | None = None, chain_method: str = 'vectorized', + n_parallel: int | None = None, progress: bool = True, moves: dict | None = None, **aies_kwargs: dict, @@ -955,9 +961,10 @@ def aies( init : dict, optional Initial parameter for sampler to start from. chain_method : str, optional - The chain method passed to :class:`numpyro.inf.MCMC`. - Only 'vectorized' or 'parallel' could be set. - Make sure `n` host devices are same with chains, If set 'parallel'. + Available options are ``'vectorized'`` and ``'parallel'``. + n_parallel : int, optional + Number of parallel chains to run when `chain_method` is + ``"parallel"``. Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar during sampling. The default is True. moves : dict, optional @@ -1020,21 +1027,21 @@ def do_mcmc(rng_key): rng_key, init_params=init, ) - return {**mcmc.get_samples()} + return mcmc.get_samples(group_by_chain=True) - # rng_keys = jax.random.split( - jax.random.PRNGKey(self._helper.seed['mcmc']), chains + jax.random.PRNGKey(self._helper.seed['mcmc']), + get_parallel_number(n_parallel), ) traces = jax.pmap(do_mcmc)(rng_keys) + trace = {k: np.concatenate(v) for k, v in traces.items()} sampler = MCMC( aies_kernel, num_warmup=warmup, num_samples=steps, ) - sampler._states = {sampler._sample_field: traces} - return PosteriorResult(sampler, self._helper, self) + sampler._states = {sampler._sample_field: trace} else: sampler = MCMC( From 150c9db9fd249481e0e7a63adc649167b99759ee Mon Sep 17 00:00:00 2001 From: xuewc Date: Thu, 29 Aug 2024 19:02:27 +0800 Subject: [PATCH 6/6] test: more tests for MCMC --- tests/test_fit.py | 68 ++++++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/tests/test_fit.py b/tests/test_fit.py index e708abb..151aff9 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -96,31 +96,43 @@ def test_trivial_bayes_fit(): seed=seed, ) - # Get Bayesian fit result, i.e. posterior - result = BayesFit(data, model).nuts() - - # check the true parameters values are within the 90% CI - ci = result.ci(cl=0.9).intervals - assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1] - assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1] - - # Check various methods of posterior result - assert result.ndata['total'] == nbins - assert result.dof == nbins - 2 - - result.ppc(1009) - result.flux(1, 2) - result.lumin(1, 10000, z=1) - result.eiso(1, 10000, z=1, duration=spec_exposure) - result.summary() - result.plot() - result.plot('data ne ene eene Fv vFv rq pit corner khat') - _ = result.deviance - _ = result.loo - _ = result.waic - assert all(i > 0.05 for i in result.gof.values()) - - plotter = result.plot - plotter.plot_qq('rd') - plotter.plot_qq('rp') - plotter.plot_qq('rq') + sampling_kwargs = { + 'nuts': {}, + 'jaxns': {}, + 'aies': { + 'chain_method': 'parallel', + 'n_parallel': 4, + }, + 'ultranest': {}, + 'nautilus': {}, + } + + for method, kwargs in sampling_kwargs.items(): + # Get Bayesian fit result, i.e. posterior + result = getattr(BayesFit(data, model), method)(**kwargs) + + # check the true parameters values are within the 90% CI + ci = result.ci(cl=0.9).intervals + assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1] + assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1] + + # Check various methods of posterior result + assert result.ndata['total'] == nbins + assert result.dof == nbins - 2 + + result.ppc(1009) + result.flux(1, 2) + result.lumin(1, 10000, z=1) + result.eiso(1, 10000, z=1, duration=spec_exposure) + result.summary() + result.plot() + result.plot('data ne ene eene Fv vFv rq pit corner khat') + _ = result.deviance + _ = result.loo + _ = result.waic + assert all(i > 0.05 for i in result.gof.values()) + + plotter = result.plot + plotter.plot_qq('rd') + plotter.plot_qq('rp') + plotter.plot_qq('rq')