Skip to content

Commit

Permalink
feat: run AIES in parallel (#106)
Browse files Browse the repository at this point in the history
Co-authored-by: S.-L. Xie <82627490+xiesl97@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent 4f50114 commit 299f51d
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 42 deletions.
70 changes: 56 additions & 14 deletions src/elisa/infer/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from elisa.infer.nested_sampling import NestedSampler, reparam_loglike
from elisa.infer.results import MLEResult, PosteriorResult
from elisa.models.model import Model, get_model_info
from elisa.util.misc import add_suffix, build_namespace, make_pretty_table
from elisa.util.misc import (
add_suffix,
build_namespace,
get_parallel_number,
make_pretty_table,
)

if TYPE_CHECKING:
from typing import Any, Callable, Literal
Expand Down Expand Up @@ -927,6 +932,7 @@ def aies(
chains: int | None = None,
init: dict[str, float] | None = None,
chain_method: str = 'vectorized',
n_parallel: int | None = None,
progress: bool = True,
moves: dict | None = None,
**aies_kwargs: dict,
Expand Down Expand Up @@ -955,7 +961,10 @@ def aies(
init : dict, optional
Initial parameter for sampler to start from.
chain_method : str, optional
The chain method passed to :class:`numpyro.inf.MCMC`.
Available options are ``'vectorized'`` and ``'parallel'``.
n_parallel : int, optional
Number of parallel chains to run when `chain_method` is
``"parallel"``. Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to show progress bar during sampling. The default is True.
moves : dict, optional
Expand Down Expand Up @@ -1002,17 +1011,50 @@ def aies(
else:
aies_kwargs['moves'] = moves

sampler = MCMC(
AIES(**aies_kwargs),
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method=chain_method,
progress_bar=progress,
)
if chain_method == 'parallel':
aies_kernel = AIES(**aies_kwargs)

def do_mcmc(rng_key):
mcmc = MCMC(
aies_kernel,
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method='vectorized',
progress_bar=False,
)
mcmc.run(
rng_key,
init_params=init,
)
return mcmc.get_samples(group_by_chain=True)

sampler.run(
rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
init_params=init,
)
rng_keys = jax.random.split(
jax.random.PRNGKey(self._helper.seed['mcmc']),
get_parallel_number(n_parallel),
)
traces = jax.pmap(do_mcmc)(rng_keys)
trace = {k: np.concatenate(v) for k, v in traces.items()}

sampler = MCMC(
aies_kernel,
num_warmup=warmup,
num_samples=steps,
)
sampler._states = {sampler._sample_field: trace}

else:
sampler = MCMC(
AIES(**aies_kwargs),
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method=chain_method,
progress_bar=progress,
)

sampler.run(
rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
init_params=init,
)
return PosteriorResult(sampler, self._helper, self)
35 changes: 35 additions & 0 deletions src/elisa/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import re
import warnings
from functools import reduce
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -243,6 +244,40 @@ def fdjvp(primals, tangents):
return fn


def get_parallel_number(n: int | None) -> int:
"""Check and return the available parallel number in JAX.
Parameters
----------
n : int, optional
The desired number of parallel processes in JAX.
Returns
-------
int
The available number of parallel processes.
"""
n_max = jax.local_device_count()

if n is None:
return n_max
else:
n = int(n)
if n <= 0:
raise ValueError('`n` must be positive')

if n > n_max:
warnings.warn(
f'number of parallel processes ({n}) is more than the number of '
f'available devices ({jax.local_device_count()}), reset to '
f'{jax.local_device_count()}',
Warning,
)
n = jax.local_device_count()

return n


def get_unit_latex(unit: str, throw: bool = True) -> str:
"""Get latex string of a unit.
Expand Down
68 changes: 40 additions & 28 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,31 +96,43 @@ def test_trivial_bayes_fit():
seed=seed,
)

# Get Bayesian fit result, i.e. posterior
result = BayesFit(data, model).nuts()

# check the true parameters values are within the 90% CI
ci = result.ci(cl=0.9).intervals
assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1]
assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1]

# Check various methods of posterior result
assert result.ndata['total'] == nbins
assert result.dof == nbins - 2

result.ppc(1009)
result.flux(1, 2)
result.lumin(1, 10000, z=1)
result.eiso(1, 10000, z=1, duration=spec_exposure)
result.summary()
result.plot()
result.plot('data ne ene eene Fv vFv rq pit corner khat')
_ = result.deviance
_ = result.loo
_ = result.waic
assert all(i > 0.05 for i in result.gof.values())

plotter = result.plot
plotter.plot_qq('rd')
plotter.plot_qq('rp')
plotter.plot_qq('rq')
sampling_kwargs = {
'nuts': {},
'jaxns': {},
'aies': {
'chain_method': 'parallel',
'n_parallel': 4,
},
'ultranest': {},
'nautilus': {},
}

for method, kwargs in sampling_kwargs.items():
# Get Bayesian fit result, i.e. posterior
result = getattr(BayesFit(data, model), method)(**kwargs)

# check the true parameters values are within the 90% CI
ci = result.ci(cl=0.9).intervals
assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1]
assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1]

# Check various methods of posterior result
assert result.ndata['total'] == nbins
assert result.dof == nbins - 2

result.ppc(1009)
result.flux(1, 2)
result.lumin(1, 10000, z=1)
result.eiso(1, 10000, z=1, duration=spec_exposure)
result.summary()
result.plot()
result.plot('data ne ene eene Fv vFv rq pit corner khat')
_ = result.deviance
_ = result.loo
_ = result.waic
assert all(i > 0.05 for i in result.gof.values())

plotter = result.plot
plotter.plot_qq('rd')
plotter.plot_qq('rp')
plotter.plot_qq('rq')

0 comments on commit 299f51d

Please sign in to comment.