From 7592ceea056f4a6e107b0e08c157043018d82087 Mon Sep 17 00:00:00 2001 From: xuewc Date: Thu, 29 Aug 2024 22:22:25 +0800 Subject: [PATCH] option to control number of parallel processes --- src/elisa/infer/helper.py | 38 ++++++++++++++++++++++++++++++-------- src/elisa/infer/results.py | 33 +++++++++++++++++++++++++-------- src/elisa/util/misc.py | 4 +++- 3 files changed, 58 insertions(+), 17 deletions(-) diff --git a/src/elisa/infer/helper.py b/src/elisa/infer/helper.py index 8187cddc..58e3d499 100644 --- a/src/elisa/infer/helper.py +++ b/src/elisa/infer/helper.py @@ -24,7 +24,11 @@ pstat, wstat, ) -from elisa.util.misc import get_unit_latex, progress_bar_factory +from elisa.util.misc import ( + get_parallel_number, + get_unit_latex, + progress_bar_factory, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -697,22 +701,23 @@ def sim_parallel_fit( run_str: str, progress: bool, update_rate: int, + n_parallel: int, ) -> dict: """Fit simulation data in parallel.""" n = len(result['valid']) - cores = jax.local_device_count() - batch = n // cores + n_parallel = int(n_parallel) + batch = n // n_parallel if progress: pbar_factory = progress_bar_factory( - n, cores, run_str=run_str, update_rate=update_rate + n, n_parallel, run_str=run_str, update_rate=update_rate ) fn = pbar_factory(fit_once) else: fn = fit_once fit_pmap = jax.pmap(lambda *args: lax.fori_loop(0, batch, fn, args)[1]) - reshape = lambda x: x.reshape((cores, -1) + x.shape[1:]) + reshape = lambda x: x.reshape((n_parallel, -1) + x.shape[1:]) result = fit_pmap( jax.tree_map(reshape, sim_data), jax.tree_map(reshape, result), @@ -727,6 +732,7 @@ def simulate_and_fit( model_values: dict[str, JAXArray], n: int = 1, parallel: bool = True, + n_parallel: int | None = None, progress: bool = True, update_rate: int = 50, run_str: str = 'Fitting', @@ -745,6 +751,9 @@ def simulate_and_fit( The number of simulations of each model value, by default 1. parallel : bool, optional Whether to fit in parallel, by default True. + n_parallel : int, optional + The number of parallel processes when `parallel` is ``True``. + Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar, by default True. update_rate : int, optional @@ -764,6 +773,7 @@ def simulate_and_fit( f'{k}_model': model_values[f'{k}_model'] for k in simulators } n = int(n) + n_parallel = get_parallel_number(n_parallel) assert set(free_params) == set(free_names) assert n > 0 @@ -809,8 +819,20 @@ def simulate_and_fit( } # fit simulation data - fit_fn = sim_parallel_fit if parallel else sim_sequence_fit - result = fit_fn(sim_data, result, init, run_str, progress, update_rate) + if parallel: + result = sim_parallel_fit( + sim_data, + result, + init, + run_str, + progress, + update_rate, + n_parallel, + ) + else: + result = sim_sequence_fit( + sim_data, result, init, run_str, progress, update_rate + ) result['data'] = sim_data return result @@ -991,6 +1013,6 @@ class Helper(NamedTuple): """Function to simulate data.""" simulate_and_fit: Callable[ - [int, dict, dict, int, bool, bool, int, str], dict + [int, dict, dict, int, bool, int, bool, int, str], dict ] """Function to simulate data and then fit the simulation data.""" diff --git a/src/elisa/infer/results.py b/src/elisa/infer/results.py index e584c552..12b4ecc0 100644 --- a/src/elisa/infer/results.py +++ b/src/elisa/infer/results.py @@ -27,7 +27,7 @@ from elisa.infer.helper import check_params from elisa.infer.nested_sampling import NestedSampler from elisa.plot.plotter import MLEResultPlotter, PosteriorResultPlotter -from elisa.util.misc import make_pretty_table +from elisa.util.misc import get_parallel_number, make_pretty_table if TYPE_CHECKING: from collections.abc import Callable, Iterable, Sequence @@ -644,6 +644,7 @@ def boot( n: int = 10000, seed: int | None = None, parallel: bool = True, + n_parallel: int | None = None, progress: bool = True, update_rate: int = 50, ): @@ -658,15 +659,18 @@ def boot( The seed of random number generator used in parametric bootstrap. parallel : bool, optional Whether to run simulation fit in parallel. The default is True. + n_parallel : int, optional + Number of parallel processes to use when `parallel` is ``True``. + Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to display progress bar. The default is True. update_rate : int, optional The update rate of progress bar. The default is 50. """ n = int(n) - n_core = jax.local_device_count() - if parallel and (n % n_core): - n += n_core - n % n_core + n_parallel = get_parallel_number(n_parallel) + if parallel and (n % n_parallel): + n += n_parallel - n % n_parallel # reuse the previous result if all setup is the same if self._boot and self._boot.n == n and self._boot.seed == seed: @@ -684,6 +688,7 @@ def boot( models, n, parallel, + n_parallel, progress, update_rate, 'Bootstrap', @@ -1401,6 +1406,7 @@ def ppc( n: int = 10000, seed: int | None = None, parallel: bool = True, + n_parallel: int | None = None, progress: bool = True, update_rate: int = 50, ): @@ -1414,15 +1420,18 @@ def ppc( The seed of random number generator used in posterior predictions. parallel : bool, optional Whether to run simulation fit in parallel. The default is True. + n_parallel : int, optional + Number of parallel processes to use when `parallel` is ``True``. + Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to display progress bar. The default is True. update_rate : int, optional The update rate of progress bar. The default is 50. """ n = int(n) - n_core = jax.local_device_count() - if parallel and (n % n_core): - n += n_core - n % n_core + n_parallel = get_parallel_number(n_parallel) + if parallel and (n % n_parallel): + n += n_parallel - n % n_parallel # reuse the previous result if all setup is the same if self._ppc and self._ppc.n == n and self._ppc.seed == seed: @@ -1448,7 +1457,15 @@ def ppc( # perform ppc result = helper.simulate_and_fit( - seed, params, models, 1, parallel, progress, update_rate, 'PPC' + seed, + params, + models, + 1, + parallel, + n_parallel, + progress, + update_rate, + 'PPC', ) valid = result.pop('valid') result = jax.tree_map(lambda x: x[valid], result) diff --git a/src/elisa/util/misc.py b/src/elisa/util/misc.py index 56c225d1..9a3f9dc0 100644 --- a/src/elisa/util/misc.py +++ b/src/elisa/util/misc.py @@ -264,7 +264,9 @@ def get_parallel_number(n: int | None) -> int: else: n = int(n) if n <= 0: - raise ValueError('`n` must be positive') + raise ValueError( + f'number of parallel processes must be positive, got {n}' + ) if n > n_max: warnings.warn(