Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option to control number of parallel processes #107

Merged
1 commit merged on Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions src/elisa/infer/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
pstat,
wstat,
)
from elisa.util.misc import get_unit_latex, progress_bar_factory
from elisa.util.misc import (
get_parallel_number,
get_unit_latex,
progress_bar_factory,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -697,22 +701,23 @@ def sim_parallel_fit(
run_str: str,
progress: bool,
update_rate: int,
n_parallel: int,
) -> dict:
"""Fit simulation data in parallel."""
n = len(result['valid'])
cores = jax.local_device_count()
batch = n // cores
n_parallel = int(n_parallel)
batch = n // n_parallel

if progress:
pbar_factory = progress_bar_factory(
n, cores, run_str=run_str, update_rate=update_rate
n, n_parallel, run_str=run_str, update_rate=update_rate
)
fn = pbar_factory(fit_once)
else:
fn = fit_once

fit_pmap = jax.pmap(lambda *args: lax.fori_loop(0, batch, fn, args)[1])
reshape = lambda x: x.reshape((cores, -1) + x.shape[1:])
reshape = lambda x: x.reshape((n_parallel, -1) + x.shape[1:])
result = fit_pmap(
jax.tree_map(reshape, sim_data),
jax.tree_map(reshape, result),
Expand All @@ -727,6 +732,7 @@ def simulate_and_fit(
model_values: dict[str, JAXArray],
n: int = 1,
parallel: bool = True,
n_parallel: int | None = None,
progress: bool = True,
update_rate: int = 50,
run_str: str = 'Fitting',
Expand All @@ -745,6 +751,9 @@ def simulate_and_fit(
The number of simulations of each model value, by default 1.
parallel : bool, optional
Whether to fit in parallel, by default True.
n_parallel : int, optional
The number of parallel processes when `parallel` is ``True``.
Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to show progress bar, by default True.
update_rate : int, optional
Expand All @@ -764,6 +773,7 @@ def simulate_and_fit(
f'{k}_model': model_values[f'{k}_model'] for k in simulators
}
n = int(n)
n_parallel = get_parallel_number(n_parallel)

assert set(free_params) == set(free_names)
assert n > 0
Expand Down Expand Up @@ -809,8 +819,20 @@ def simulate_and_fit(
}

# fit simulation data
fit_fn = sim_parallel_fit if parallel else sim_sequence_fit
result = fit_fn(sim_data, result, init, run_str, progress, update_rate)
if parallel:
result = sim_parallel_fit(
sim_data,
result,
init,
run_str,
progress,
update_rate,
n_parallel,
)
else:
result = sim_sequence_fit(
sim_data, result, init, run_str, progress, update_rate
)
result['data'] = sim_data
return result

Expand Down Expand Up @@ -991,6 +1013,6 @@ class Helper(NamedTuple):
"""Function to simulate data."""

simulate_and_fit: Callable[
[int, dict, dict, int, bool, bool, int, str], dict
[int, dict, dict, int, bool, int, bool, int, str], dict
]
"""Function to simulate data and then fit the simulation data."""
33 changes: 25 additions & 8 deletions src/elisa/infer/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from elisa.infer.helper import check_params
from elisa.infer.nested_sampling import NestedSampler
from elisa.plot.plotter import MLEResultPlotter, PosteriorResultPlotter
from elisa.util.misc import make_pretty_table
from elisa.util.misc import get_parallel_number, make_pretty_table

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence
Expand Down Expand Up @@ -644,6 +644,7 @@ def boot(
n: int = 10000,
seed: int | None = None,
parallel: bool = True,
n_parallel: int | None = None,
progress: bool = True,
update_rate: int = 50,
):
Expand All @@ -658,15 +659,18 @@ def boot(
The seed of random number generator used in parametric bootstrap.
parallel : bool, optional
Whether to run simulation fit in parallel. The default is True.
n_parallel : int, optional
Number of parallel processes to use when `parallel` is ``True``.
Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to display progress bar. The default is True.
update_rate : int, optional
The update rate of progress bar. The default is 50.
"""
n = int(n)
n_core = jax.local_device_count()
if parallel and (n % n_core):
n += n_core - n % n_core
n_parallel = get_parallel_number(n_parallel)
if parallel and (n % n_parallel):
n += n_parallel - n % n_parallel

# reuse the previous result if all setup is the same
if self._boot and self._boot.n == n and self._boot.seed == seed:
Expand All @@ -684,6 +688,7 @@ def boot(
models,
n,
parallel,
n_parallel,
progress,
update_rate,
'Bootstrap',
Expand Down Expand Up @@ -1401,6 +1406,7 @@ def ppc(
n: int = 10000,
seed: int | None = None,
parallel: bool = True,
n_parallel: int | None = None,
progress: bool = True,
update_rate: int = 50,
):
Expand All @@ -1414,15 +1420,18 @@ def ppc(
The seed of random number generator used in posterior predictions.
parallel : bool, optional
Whether to run simulation fit in parallel. The default is True.
n_parallel : int, optional
Number of parallel processes to use when `parallel` is ``True``.
Defaults to ``jax.local_device_count()``.
progress : bool, optional
Whether to display progress bar. The default is True.
update_rate : int, optional
The update rate of progress bar. The default is 50.
"""
n = int(n)
n_core = jax.local_device_count()
if parallel and (n % n_core):
n += n_core - n % n_core
n_parallel = get_parallel_number(n_parallel)
if parallel and (n % n_parallel):
n += n_parallel - n % n_parallel

# reuse the previous result if all setup is the same
if self._ppc and self._ppc.n == n and self._ppc.seed == seed:
Expand All @@ -1448,7 +1457,15 @@ def ppc(

# perform ppc
result = helper.simulate_and_fit(
seed, params, models, 1, parallel, progress, update_rate, 'PPC'
seed,
params,
models,
1,
parallel,
n_parallel,
progress,
update_rate,
'PPC',
)
valid = result.pop('valid')
result = jax.tree_map(lambda x: x[valid], result)
Expand Down
4 changes: 3 additions & 1 deletion src/elisa/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ def get_parallel_number(n: int | None) -> int:
else:
n = int(n)
if n <= 0:
raise ValueError('`n` must be positive')
raise ValueError(
f'number of parallel processes must be positive, got {n}'
)

if n > n_max:
warnings.warn(
Expand Down