Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run AIES in parallel #106

Merged
merged 6 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions src/elisa/infer/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from elisa.infer.nested_sampling import NestedSampler, reparam_loglike
from elisa.infer.results import MLEResult, PosteriorResult
from elisa.models.model import Model, get_model_info
from elisa.util.misc import add_suffix, build_namespace, make_pretty_table
from elisa.util.misc import (
add_suffix,
build_namespace,
get_parallel_number,
make_pretty_table,
)

if TYPE_CHECKING:
from typing import Any, Callable, Literal
Expand Down Expand Up @@ -927,6 +932,7 @@ def aies(
chains: int | None = None,
init: dict[str, float] | None = None,
chain_method: str = 'vectorized',
n_parallel: int | None = None,
progress: bool = True,
moves: dict | None = None,
**aies_kwargs: dict,
Expand Down Expand Up @@ -955,7 +961,10 @@ def aies(
init : dict, optional
Initial parameter for sampler to start from.
chain_method : str, optional
The chain method passed to :class:`numpyro.inf.MCMC`.
Available options are ``'vectorized'`` and ``'parallel'``.
n_parallel : int, optional
Number of parallel chains to run when `chain_method` is
``"parallel"``. Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to show progress bar during sampling. The default is True.
moves : dict, optional
Expand Down Expand Up @@ -1002,17 +1011,50 @@ def aies(
else:
aies_kwargs['moves'] = moves

sampler = MCMC(
AIES(**aies_kwargs),
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method=chain_method,
progress_bar=progress,
)
if chain_method == 'parallel':
aies_kernel = AIES(**aies_kwargs)

def do_mcmc(rng_key):
mcmc = MCMC(
aies_kernel,
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method='vectorized',
progress_bar=False,
)
mcmc.run(
rng_key,
init_params=init,
)
return mcmc.get_samples(group_by_chain=True)

sampler.run(
rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
init_params=init,
)
rng_keys = jax.random.split(
jax.random.PRNGKey(self._helper.seed['mcmc']),
get_parallel_number(n_parallel),
)
traces = jax.pmap(do_mcmc)(rng_keys)
trace = {k: np.concatenate(v) for k, v in traces.items()}

sampler = MCMC(
aies_kernel,
num_warmup=warmup,
num_samples=steps,
)
sampler._states = {sampler._sample_field: trace}

else:
sampler = MCMC(
AIES(**aies_kwargs),
num_warmup=warmup,
num_samples=steps,
num_chains=chains,
chain_method=chain_method,
progress_bar=progress,
)

sampler.run(
rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']),
init_params=init,
)
return PosteriorResult(sampler, self._helper, self)
35 changes: 35 additions & 0 deletions src/elisa/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import re
import warnings
from functools import reduce
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -243,6 +244,40 @@ def fdjvp(primals, tangents):
return fn


def get_parallel_number(n: int | None) -> int:
"""Check and return the available parallel number in JAX.

Parameters
----------
n : int, optional
The desired number of parallel processes in JAX.

Returns
-------
int
The available number of parallel processes.
"""
n_max = jax.local_device_count()

if n is None:
return n_max
else:
n = int(n)
if n <= 0:
raise ValueError('`n` must be positive')

if n > n_max:
warnings.warn(
f'number of parallel processes ({n}) is more than the number of '
f'available devices ({jax.local_device_count()}), reset to '
f'{jax.local_device_count()}',
Warning,
)
n = jax.local_device_count()

return n


def get_unit_latex(unit: str, throw: bool = True) -> str:
"""Get latex string of a unit.

Expand Down
68 changes: 40 additions & 28 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,31 +96,43 @@ def test_trivial_bayes_fit():
seed=seed,
)

# Get Bayesian fit result, i.e. posterior
result = BayesFit(data, model).nuts()

# check the true parameters values are within the 90% CI
ci = result.ci(cl=0.9).intervals
assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1]
assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1]

# Check various methods of posterior result
assert result.ndata['total'] == nbins
assert result.dof == nbins - 2

result.ppc(1009)
result.flux(1, 2)
result.lumin(1, 10000, z=1)
result.eiso(1, 10000, z=1, duration=spec_exposure)
result.summary()
result.plot()
result.plot('data ne ene eene Fv vFv rq pit corner khat')
_ = result.deviance
_ = result.loo
_ = result.waic
assert all(i > 0.05 for i in result.gof.values())

plotter = result.plot
plotter.plot_qq('rd')
plotter.plot_qq('rp')
plotter.plot_qq('rq')
sampling_kwargs = {
'nuts': {},
'jaxns': {},
'aies': {
'chain_method': 'parallel',
'n_parallel': 4,
},
'ultranest': {},
'nautilus': {},
}

for method, kwargs in sampling_kwargs.items():
# Get Bayesian fit result, i.e. posterior
result = getattr(BayesFit(data, model), method)(**kwargs)

# check the true parameters values are within the 90% CI
ci = result.ci(cl=0.9).intervals
assert ci['PowerLaw.K'][0] < K < ci['PowerLaw.K'][1]
assert ci['PowerLaw.alpha'][0] < alpha < ci['PowerLaw.alpha'][1]

# Check various methods of posterior result
assert result.ndata['total'] == nbins
assert result.dof == nbins - 2

result.ppc(1009)
result.flux(1, 2)
result.lumin(1, 10000, z=1)
result.eiso(1, 10000, z=1, duration=spec_exposure)
result.summary()
result.plot()
result.plot('data ne ene eene Fv vFv rq pit corner khat')
_ = result.deviance
_ = result.loo
_ = result.waic
assert all(i > 0.05 for i in result.gof.values())

plotter = result.plot
plotter.plot_qq('rd')
plotter.plot_qq('rp')
plotter.plot_qq('rq')